Dropout support for memory efficient attention (#102038)
# Summary
This PR builds off of:
- https://github.com/pytorch/pytorch/pull/101847
- https://github.com/pytorch/pytorch/pull/100583
It specifically adds dropout support to the memory efficient attention kernel. In the process of doing so roughly 3 changes were made:
- Update sdpa dispatching to allow for inputs requiring grad to be sent to efficient attention
- Update how memory efficient attention handles passing the rng state from forward to backward in order to enable cuda_graph support
- Fix a bug in the kernel that was causing incorrect gradients to be produced for num_keys > 64 with dropout and causal masking set. https://github.com/facebookresearch/xformers/pull/755
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102038
Approved by: https://github.com/cpuhrsch