a37e22de - Add Flash Attention support on ROCM (#121561)

This patch addresses the major limitations in our previous [PR #115981](https://github.com/pytorch/pytorch/pull/115981) through the new dedicated repository [AOTriton](https://github.com/ROCm/aotriton):

- [x] Only supports MI200 series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
  * MI300X is now supported. More architectures will be added once Triton supports them.
- [x] Only supports power-of-two sequence lengths.
  * Arbitrary sequence lengths are now supported.
- [ ] No support for varlen APIs.
  * The varlen API will be supported in the next release of AOTriton.
- [x] Only supports head dimensions 16, 32, 64, and 128.
  * Arbitrary head dimensions <= 256 are now supported.
- [x] Performance is still being optimized.
  * Kernels are selected according to autotune information from Triton.

Other improvements from AOTriton include:

* More flexible Tensor storage layouts
* A more flexible API

This is a more extensive fix to #112997.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121561
Approved by: https://github.com/malfet, https://github.com/atalman
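As a usage illustration (not part of the commit itself), the sketch below shows how the new Flash Attention path might be exercised on a ROCm build through PyTorch's `scaled_dot_product_attention`, forcing the flash backend via the `torch.backends.cuda.sdp_kernel` context manager. The tensor shapes, including the non-power-of-two sequence length, are arbitrary choices for illustration; a supported GPU (MI200 or MI300X) is assumed.

```python
# A minimal sketch, assuming a ROCm build of PyTorch and a supported GPU
# (MI200 or MI300X). On ROCm builds the device string is still "cuda".
import torch
import torch.nn.functional as F

# batch=2, heads=8, seq_len=1000 (non-power-of-two, now allowed), head_dim=64
q = torch.randn(2, 8, 1000, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1000, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1000, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the Flash Attention backend; an error is raised if the
# flash kernel cannot serve these inputs on this hardware.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 1000, 64])
```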