9613395e - [SDPA] Integrating the main branch of flash_attn instead of cutlass (#91994)

### Background

Early on in the process of integrating the FlashAttention code into core, we spoke with Tri and concluded that the main branch of FlashAttention wasn't suitable for integration. We instead went with a [refactored version](https://github.com/HazyResearch/flash-attention/tree/cutlass) that depends more heavily on cutlass. That is the current version of FlashAttention in PyTorch. However, that branch has some limitations:

- No backward support for SDPA.
- Not as performant for some large MHA setups.

### Summary

This PR pulls in the latest version of the main branch of [FlashAttention](https://github.com/HazyResearch/flash-attention/tree/main). It does not register the backward for the aten function SDPA_flash_attn; that will be done in a follow-up PR.

### Changeset

A few changes were made to the original code for PyTorch:

- Flattened one layer of folder structure, to match the structure of the existing FlashAttention code in core.
- Removed the `return_softmax` param and changed the `mha_fwd` signature. Since the public SDPA function in core does not support `need_weights`, we remove this argument.
- Added `#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530` guards around sections of code that will not compile for architectures at or below sm_52. Most of these blocks are half-precision asm or `__hmul2` operations. An example update:

```cpp
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
  float f;
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
  return f;
#else
  assert(false);
  return 0;
#endif
}
```

- Removed the blocksparse functions and files, and commented out utility functions that are only used in the blocksparse kernels written for FlashAttention, since we did not pull in those kernels.
- Updated `gemm_cl` in **/gemm.h to:

```cpp
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
  assert(0); // THIS IS NOT CORRECT BUT THE ASSERT WILL STOP THIS
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  // TD [2022-06-02] We don't support Volta (SM70) yet.
#endif
```

### Reasoning

FlashAttention is only designed to run on GPUs that support sm_7.5 or later. However, PyTorch is generally built and released with `TORCH_CUDA_ARCH_LIST=5.2,..,8.6`, which means the source code must be compilable for these lower archs even if it is never run on them. But how are we sure that it won't be run? That should be handled by the runtime dispatch mechanism, specifically here: [check_arch](https://github.com/pytorch/pytorch/blob/d70ed68162521341060b06985620cdbef04a8fa9/aten/src/ATen/native/transformers/cuda/sdp_utils.h#L308)

There is, however, one edge case when building from source: the user specifies TORCH_CUDA_ARCH_LIST={something less than 7.5} and runs on a GPU that is >= 7.5. This will cause the runtime dispatcher to think it is okay to run FlashAttention even though the compiled code is bogus. I tested this with arch=5.3 on an A100 and got the following result: `RuntimeError: CUDA error: no kernel image is available for execution on the device`, raised from torch.rand.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91994
Approved by: https://github.com/cpuhrsch
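As a minimal, standalone sketch of the kind of runtime capability gate the dispatcher relies on (this is not the actual `check_arch` helper in `sdp_utils.h`; the function name `device_supports_flash_attention` and the `main` driver are hypothetical, used here only for illustration), assuming just the CUDA runtime API:

```cpp
// Hypothetical sketch of a runtime compute-capability check; NOT the real
// check_arch from sdp_utils.h.
#include <cstdio>
#include <cuda_runtime.h>

bool device_supports_flash_attention(int device_index) {
  cudaDeviceProp prop;
  // Query the device's compute capability at runtime.
  if (cudaGetDeviceProperties(&prop, device_index) != cudaSuccess) {
    return false;
  }
  // FlashAttention requires sm_75 (compute capability 7.5) or newer.
  return prop.major > 7 || (prop.major == 7 && prop.minor >= 5);
}

int main() {
  if (device_supports_flash_attention(/*device_index=*/0)) {
    std::printf("Runtime dispatch would allow the FlashAttention kernels.\n");
  } else {
    std::printf("Runtime dispatch would fall back to another SDPA backend.\n");
  }
  return 0;
}
```

A capability check of this shape only inspects the GPU, not the set of architectures the binary was actually compiled for, which is why the TORCH_CUDA_ARCH_LIST edge case described above can slip past the dispatcher.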