pytorch
606fb882 - Dropout support for memory efficient attention (#102038)

Commit
1 year ago
Dropout support for memory efficient attention (#102038) # Summary This PR builds off of: - https://github.com/pytorch/pytorch/pull/101847 - https://github.com/pytorch/pytorch/pull/100583 It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made: - Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention - Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support - Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/102038 Approved by: https://github.com/cpuhrsch
Author
Committer
Parents
Loading