add FlashAttentionKwargs and seq_idx to flat collator (#36456)
* add flash attn kwargs to flattening collator
* add return_seq_idx option
* doc string edits
* cleaner max len updates
* various fixes
* temp testing code
* return int32 seq_idx and FlashAttentionKwargs
* DataCollatorIntegrationTest impl
* fix batch dims and dtypes
* fill out remaining collator tests
* test name change and fmt
* rm unused var
* fmt
* minor change
* fmt
* add missing pos_ids check
* consistent {np,pt,tf} tests
* split pt tests into 3, like np/tf tests
* mv comment, rename fa test
* remove batch dim comment
* simplify wrapping
* compute cu_seq_len/max_length once
* fmt
* remove tf code
* rm warning
* move separator_id back to 2nd pos
* use cleaner lists in tests
* ret -> batch
* fmt
* attr ordering
* use py ints for max_length_{k,q}
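
For reviewers, a minimal sketch of the values a flattening collator could return once these options are on. It is illustrative only and is not the code in this PR: `flatten_with_fa_kwargs` is a hypothetical helper, plain Python lists stand in for tensors, and the dictionary keys are assumed to match the FlashAttentionKwargs field names.

```python
from itertools import accumulate


def flatten_with_fa_kwargs(sequences):
    """Hypothetical helper: pack non-empty token lists into one flat row."""
    lengths = [len(seq) for seq in sequences]

    # Concatenate all tokens into a single row.
    input_ids = [tok for seq in sequences for tok in seq]
    # Position ids restart at 0 at the start of every sequence.
    position_ids = [pos for n in lengths for pos in range(n)]
    # seq_idx marks which original sequence each token came from.
    seq_idx = [i for i, n in enumerate(lengths) for _ in range(n)]

    # Cumulative boundaries and the longest length are computed once and
    # reused for both the query and key sides.
    cu_seq_lens = [0, *accumulate(lengths)]
    max_length = max(lengths)  # a plain Python int

    return {
        "input_ids": [input_ids],  # outer list plays the role of a batch dim of 1
        "position_ids": [position_ids],
        "seq_idx": [seq_idx],
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max_length,
        "max_length_k": max_length,
    }


batch = flatten_with_fa_kwargs([[10, 11, 12], [20, 21], [30, 31, 32, 33]])
```

In this sketch the three sequences have lengths 3, 2 and 4, so the cumulative boundaries are 0, 3, 5, 9 and the maximum length is 4.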