pytorch
1f8177de - [Inductor][CPU] fix flash attention last_stride!=1 issue (#122083)

Commit
1 year ago
[Inductor][CPU] fix flash attention last_stride!=1 issue (#122083) Fixes #121174. Conv converts the input of sdpa to channel last, resulting in accuracy issue. Ensure the layout in lowering. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122083 Approved by: https://github.com/eellison, https://github.com/jgong5
Author
Committer
Parents
Loading