[WIP] [ATen] Add native_multi_attention_self_attention CPU + GPU implementation (#70649)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70649
As described in https://fb.quip.com/oxpiA1uDBjgP
This implements the first parts of the RFC and is a rough draft showing the approach. For the first cut, the goal is to stay numerically very close to the existing nn.MHA implementation (identical in this diff, I believe). Once a working native self-attention implementation has been adopted, subsequent diffs can explore alternative implementations.
The current implementation is similar to existing dedicated implementations such as LightSeq/FasterTransformer/DeepSpeed, and for MHA it is between 1.2x and 2x faster on both CPUs and GPUs, depending on the setting. It makes some approximations/restrictions (for example, it doesn't handle masking in the masked softmax), but these shouldn't materially impact performance.
This does the first few items:
* Add native_multi_head_attention(..), native_multi_head_attention_backward(..) to native_functions.yaml
* Implement native_multi_head_attention(..) on GPU, extracting bits and pieces out of LS/DS/FT as appropriate
* Implement native_multi_head_attention(..) on CPU
The backward implementation is still WIP, but the idea would be to:
* Hook these up in derivatives.yaml
* Implement native_multi_head_attention_backward(..) on GPU, extracting bits and pieces out of LS/DS (not FT since it’s inference only)
* Implement native_multi_head_attention_backward(..) on CPU
* In torch.nn.functional.multi_head_attention_forward (https://github.com/pytorch/pytorch/blob/23321ba7a3b634ee734455aab4a984689802cad0/torch/nn/functional.py#L4953), add some conditionals to check whether we are being called in a BERT/ViT-style encoder fashion, and if so invoke the native function directly.
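A minimal sketch of what the conditional in the last item might look like, written in plain Python so it runs without PyTorch. The function name can_use_native_self_attention and its parameters are hypothetical and not part of this diff. The two restrictions it encodes are assumptions drawn from the notes above: no masking (the draft forward doesn't handle it) and inference only (the backward is still WIP).

```python
from typing import Any, Optional


def can_use_native_self_attention(
    query: Any,
    key: Any,
    value: Any,
    attn_mask: Optional[Any] = None,
    key_padding_mask: Optional[Any] = None,
    training: bool = False,
) -> bool:
    """Hypothetical predicate: True only for plain encoder-style self-attention."""
    # Encoder-style self-attention: query, key and value are the same object.
    is_self_attention = query is key and key is value
    # Assumption: the draft forward does not handle masks, so any mask
    # forces the existing Python path.
    has_mask = attn_mask is not None or key_padding_mask is not None
    # Assumption: with the backward still WIP, only inference takes the fast path.
    return is_self_attention and not has_mask and not training
```

When the predicate is false, the caller would fall through to the existing multi_head_attention_forward code path unchanged, so the existing behavior stays the default.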
Test Plan: TODO
Reviewed By: mikekgfb
Differential Revision: D31829981
fbshipit-source-id: c430344d91ba7a5fbee3138e50b3e62efbb33d96