pytorch
28d69d52 - Adding Backward Support for NestedTensors and FlashAttention (#97485)

Commit
1 year ago
Adding Backward Support for NestedTensors and FlashAttention (#97485) # Summary <!-- copilot:summary --> ### <samp>🤖 Generated by Copilot at 318764f</samp> This pull request implements the CUDA backend of the SDPA kernel for nested tensors, which enables efficient transformer models with variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`. <!-- copilot:poem --> ### <samp>🤖 Generated by Copilot at ed4a773</samp> > _Fused kernels of doom, unleash the flash attention_ > _Nested tensors on fire, reshape and pad with caution_ > _Backward pass of power, dispatch the CUDA key_ > _Test the gradients of hell, warn the user if they disagree_ Pull Request resolved: https://github.com/pytorch/pytorch/pull/97485 Approved by: https://github.com/jbschlosser
Author
Committer
Parents
Loading