refactor multi_head_attention_forward (#56674)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56674
`torch.nn.functional.multi_head_attention_forward` supports a long tail of options and variations of the multihead attention computation. Its complexity is mostly due to arbitrating among options, preparing values in multiple ways, and so on - the attention computation itself is a small fraction of the implementation logic, which is relatively simple but can be hard to pick out.
The goal of this PR is to
- make the internal logic of `multi_head_attention_forward` less entangled and more readable, with the attention computation steps easily discernible from their surroundings.
- factor out simple helpers to perform the actual attention steps, with the aim of making them available to other attention-computing contexts.
Note that these changes should leave the signature and output of `multi_head_attention_forward` completely unchanged, so not BC-breaking. Later PRs should present new multihead attention entry points, but deprecating this one is out of scope for now.
Changes are in two parts:
- the implementation of `multi_head_attention_forward` has been extensively resequenced, which makes the rewrite look more total than it actually is. Changes to argument-processing logic are largely confined to a) minor perf tweaks/control flow tightening, b) error message improvements, and c) argument prep changes due to helper function factoring (e.g. merging `key_padding_mask` with `attn_mask` rather than applying them separately)
- factored helper functions are defined just above `multi_head_attention_forward`, with names prefixed with `_`. (A future PR may pair them with corresponding modules, but for now they're private.)
Test Plan: Imported from OSS
Reviewed By: gmagogsfm
Differential Revision: D28344707
Pulled By: bhosmer
fbshipit-source-id: 3bd8beec515182c3c4c339efc3bec79c0865cb9a