bb0f299f - Update MultiheadAttention module to support key/value with a different number of features and allow static key/value (#21288)

Commit
5 years ago
Summary:
The changes include:
1. Allow key/value to have a different number of features from query. This supports the case where the key and value feature dimensions differ from the query's.
2. Support three separate proj_weights (for query, key, and value), in addition to a single in_proj_weight. The proj_weight of key and value may have a different dimension from that of query, so three separate proj_weights are necessary. When key and value have the same dimension as query, a single large proj_weight is preferred for performance reasons; note that choosing between a single large weight and three separate weights is a size-dependent decision.
3. Give an option to use static k and v in the multihead_attn operator (see saved_k and saved_v). Those static key/value tensors can now be re-used when training the model.
4. Add more test cases to cover the new arguments.

Note: current users should not be affected by these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/21288
Differential Revision: D15738808
Pulled By: zhangguanheng66
fbshipit-source-id: 288b995787ad55fba374184b3d15b5c6fe9abb5c
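The snippet below is a minimal usage sketch (not part of the commit itself) of the kdim/vdim options on torch.nn.MultiheadAttention that items 1 and 2 describe; the dimensions and tensor shapes are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
kdim, vdim = 8, 12  # key/value feature sizes differ from the query's embed_dim

# kdim/vdim let key and value carry a different number of features than query
mha = nn.MultiheadAttention(embed_dim, num_heads, kdim=kdim, vdim=vdim)

query = torch.randn(5, 2, embed_dim)  # (tgt_len, batch, embed_dim)
key   = torch.randn(7, 2, kdim)       # (src_len, batch, kdim)
value = torch.randn(7, 2, vdim)       # (src_len, batch, vdim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([5, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 7])
```

In this configuration the module keeps three separate projection weights (q_proj_weight, k_proj_weight, v_proj_weight) rather than the single in_proj_weight, which corresponds to item 2 above; when kdim and vdim equal embed_dim, the single combined weight is used instead.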
Author
Guanheng Zhang