Update MultiheadAttention module to support key/value with a different number of features and allow static key/value (#21288)
Summary:
The changes include:
1. Allow key/value to have a different number of features from query. Key and value may also differ from each other in feature dimension.
2. Support three separate proj_weight tensors, in addition to a single in_proj_weight. The proj_weight of key and value may have different dimensions from that of query, so three separate proj_weights are necessary in that case. When key and value have the same dimension as query, a single large in_proj_weight is preferred for performance reasons. Note, however, that which of the two layouts (one large weight or three separate weights) performs better depends on the tensor sizes.
3. Add an option to use static k and v in the multihead_attn operator (see saved_k and saved_v). These static key/value tensors can be reused when training the model.
4. Add more test cases to cover the arguments.
Note: current users should not be affected by the changes.
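The weight layout described in items 1 and 2 can be illustrated with a small, dependency-free sketch. It is not the PyTorch implementation: the helper names (make_weight, project, build_projections, project_qkv) are hypothetical, and the assumed shapes are (embed_dim x embed_dim) for the query weight, (embed_dim x kdim) for the key weight, and (embed_dim x vdim) for the value weight, with one packed (3*embed_dim x embed_dim) weight when all three feature sizes match.

```python
import random


def make_weight(rows, cols, rng):
    """Return a rows x cols matrix (list of lists) with small random entries."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def project(weight, x):
    """Apply a linear map; weight is (out_dim x in_dim), x has length in_dim."""
    assert all(len(row) == len(x) for row in weight), "feature dimension mismatch"
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def build_projections(embed_dim, kdim, vdim, rng):
    """Hypothetical helper: choose a packed or a separate weight layout."""
    if kdim == embed_dim and vdim == embed_dim:
        # All feature sizes match: one packed weight holding q, k, v row blocks.
        return {"packed": make_weight(3 * embed_dim, embed_dim, rng)}
    # Feature sizes differ: each input needs its own projection weight.
    return {
        "q": make_weight(embed_dim, embed_dim, rng),
        "k": make_weight(embed_dim, kdim, rng),
        "v": make_weight(embed_dim, vdim, rng),
    }


def project_qkv(weights, query, key, value, embed_dim):
    """Project query, key, and value into the shared embed_dim space."""
    if "packed" in weights:
        w = weights["packed"]
        w_q = w[:embed_dim]
        w_k = w[embed_dim:2 * embed_dim]
        w_v = w[2 * embed_dim:]
    else:
        w_q, w_k, w_v = weights["q"], weights["k"], weights["v"]
    return project(w_q, query), project(w_k, key), project(w_v, value)


# Usage: key has 6 features and value has 3, while query has embed_dim = 4.
rng = random.Random(0)
embed_dim, kdim, vdim = 4, 6, 3
weights = build_projections(embed_dim, kdim, vdim, rng)
q, k, v = project_qkv(
    weights, [1.0] * embed_dim, [1.0] * kdim, [1.0] * vdim, embed_dim
)
```

After projection, q, k, and v all have length embed_dim regardless of the input feature sizes, which is why attention can proceed unchanged downstream. The packed layout is only possible when the three input sizes agree, since its row blocks must share one column count.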
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21288
Differential Revision: D15738808
Pulled By: zhangguanheng66
fbshipit-source-id: 288b995787ad55fba374184b3d15b5c6fe9abb5c