transformers
3ee72af6 - Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 (#37332)

Committed 187 days ago
Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 (#37332)

* Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1
* fix code format
* add test; replace position_ids with query_states because position_ids.shape[0] is always 1
* add assert loss is not nan
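The third bullet notes that `position_ids` is often broadcast with a leading dimension of 1, so `position_ids.shape[0]` does not reflect the true batch size, while `query_states` always carries it. A minimal sketch of that idea, with a hypothetical helper name (not the actual patched function in transformers):

```python
import torch

def infer_batch_and_seq_len(query_states: torch.Tensor,
                            position_ids: torch.Tensor):
    """Hypothetical illustration of the commit's point.

    position_ids is typically broadcast to shape (1, seq_len), so its
    first dimension is always 1 regardless of the real batch size.
    Reading the batch size from query_states (batch, heads, seq, dim)
    is reliable and avoids branching on a misleading dimension, which
    can cause torch.compile graph breaks.
    """
    batch_size = query_states.shape[0]   # true batch size
    seq_len = query_states.shape[2]      # sequence length
    return batch_size, seq_len
```

For example, with a batch of 4 sequences of length 16, `position_ids.shape[0]` is still 1, but the helper recovers the real batch size from `query_states`.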