eabb0a08 - [pytorch][nn] torch.nn.MultiheadAttention fix (#73053)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73053

We found that the `self-attention` case (where q, k, v are the same tensor) in https://fburl.com/code/ebgri4k5 is never hit when `batch_first=True` is passed to `torch.nn.MultiheadAttention`. Even when q, k, v start out as the same object, with `batch_first=True` each of them is transposed separately in https://fburl.com/code/h26mlbti. After those transposes, q, k, v are distinct objects, so the identity checks "`k is v` or `q is k`" no longer hold. We adjust the transpose strategy so that it aligns with the logic in https://fburl.com/code/0mg6i4yn and the `self-attention` and `encoder-decoder attention` conditions are still satisfied when `batch_first=True`.

Test Plan: Use the ads Transformer 1x model (where q, k, v are the same) as an example and run:

```
CUDA_VISIBLE_DEVICES=7 buck run mode/opt -c fbcode.nvcc_arch=a100 -c fbcode.enable_nccl_a2a=1 //hpc/models/ads:ads_transformer_1x_2021h1_launcher -- +launcher=local launcher.num_trainers=1 +data_loader=random +mode=mast model.shrink=true model.max_ind_range=1000 data_loader.num_batches=2500 profiling_trace=true
```

After this fix, with a MultiheadAttention input of shape (`batch_size=512, length=80, emb_dim=512`), QPS improves from 4705 to 4937. This is because we avoid the unnecessary extra linear-module instantiation in forward/backward when q, k, v are the same and `batch_first=True`.
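The failure mode can be sketched in plain Python (this is an illustration of the identity-check pattern, not PyTorch's actual implementation; `FakeTensor` is a stand-in for `torch.Tensor`, whose `transpose` likewise returns a new object):

```python
# Sketch: why per-tensor transposes break the `q is k` / `k is v`
# identity checks that gate the self-attention fast path.

class FakeTensor:
    def transpose(self, dim0, dim1):
        # Like torch.Tensor.transpose, returns a NEW object (a view).
        return FakeTensor()

x = FakeTensor()
q = k = v = x                     # self-attention: same object three times
assert q is k and k is v          # fast-path condition holds

# Naive batch_first handling: transpose each input separately.
q2, k2, v2 = q.transpose(1, 0), k.transpose(1, 0), v.transpose(1, 0)
assert q2 is not k2               # identity lost -> fast path is skipped

# Adjusted strategy: check identity first, transpose once, re-share.
if q is k and k is v:
    q3 = k3 = v3 = q.transpose(1, 0)
assert q3 is k3 and k3 is v3      # self-attention condition preserved
```

Checking identity before transposing keeps the shared-object invariant, so the downstream `k is v` / `q is k` branches still select the fused self-attention projection.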
The trace examples show the difference.

Before the fix (see [trace example](https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/hpc.huaqingxiong.local/rank-0/rank-0.Feb_17_16_10_15.1653768.trace.gz&bucket=hpc_traces)); the key part is in the following screenshot: {F702341108}

After the fix (see [trace example](https://our.intern.facebook.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/hpc.huaqingxiong.local/rank-0/rank-0.Feb_17_15_59_03.1533566.trace.gz&bucket=hpc_traces)): {F702343574}

Reviewed By: albanD

Differential Revision: D34321018

fbshipit-source-id: 2da5315c1fd349fdf7ebb21c25d636be216f5719
(cherry picked from commit 3d93fd68d948dc1d7533ac8039553258d3e1a46e)
Author: Huaqing Xiong