Fix graph break in torch.compile when using FA2 with attention_mask=None and batch size > 1 #37332
Fix graph break in torch.compile when using FA2 with attention_mask=N…
422e3b3b
efsotr
marked this pull request as ready for review 266 days ago
fix code format
69a2b038
Merge remote-tracking branch 'upstream/main' into fa2_compile_graph_b…
d9424933
add test; replace position_ids with query_states because position_ids…
1dff2564
Merge remote-tracking branch 'upstream/main' into fa2_compile_graph_b…
812cb5a5
Merge branch 'main' into fa2_compile_graph_break
a381b48d
Merge branch 'main' into fa2_compile_graph_break
b1c101bf
Merge branch 'main' into fa2_compile_graph_break
7cb8d411
add assert loss is not nan
5028afc6
Merge branch 'main' into fa2_compile_graph_break
ba3ad0c3
Assignees
No one assigned