[pytorch][nn] torch.nn.MultiheadAttention fix (#73053)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73053
We found we cannot hit the `self-attention` case (where q, k, v are the same tensor) in https://fburl.com/code/ebgri4k5 when torch.nn.MultiheadAttention is used with `batch_first=True`. The reason is that even when q, k, v are the same tensor, the module transposes each of them separately in https://fburl.com/code/h26mlbti when `batch_first=True`. After those transposes, q, k, v are distinct objects, so the identity checks `k is v` and `q is k` no longer hold.
This diff adjusts the transpose strategy so that it aligns with the logic in https://fburl.com/code/0mg6i4yn, and the `self-attention` and `encoder-decoder attention` fast paths can still be taken when `batch_first=True`.
Test Plan:
Use the ads Transformer 1x model (where q, k, v are the same) as an example and run:
```
CUDA_VISIBLE_DEVICES=7 buck run mode/opt -c fbcode.nvcc_arch=a100 -c fbcode.enable_nccl_a2a=1 //hpc/models/ads:ads_transformer_1x_2021h1_launcher -- +launcher=local launcher.num_trainers=1 +data_loader=random +mode=mast model.shrink=true model.max_ind_range=1000 data_loader.num_batches=2500 profiling_trace=true
```
After this fix, with a MultiheadAttention input of shape (`batch_size=512, length=80, emb_dim=512`), QPS improves from 4705 to 4937. This is because we avoid the unnecessary extra linear projections in forward/backward when q, k, v are the same and `batch_first=True`. The trace examples show the difference:
Before fixing (see [trace example](https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/hpc.huaqingxiong.local/rank-0/rank-0.Feb_17_16_10_15.1653768.trace.gz&bucket=hpc_traces)) and the key part in the following screenshot:
{F702341108}
After fixing (see [trace example](https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/hpc.huaqingxiong.local/rank-0/rank-0.Feb_17_15_59_03.1533566.trace.gz&bucket=hpc_traces))
{F702343574}
Reviewed By: albanD
Differential Revision: D34321018
fbshipit-source-id: 2da5315c1fd349fdf7ebb21c25d636be216f5719
(cherry picked from commit 3d93fd68d948dc1d7533ac8039553258d3e1a46e)