pytorch
be8c7c06 - [Tensor Parallel] Simplify distribute for MHA (#100046)

Commit
1 year ago
[Tensor Parallel] Simplify distribute for MHA (#100046) This function is only called for nn.MHA or the custom MHA we use, and if it is the former it is converted to the latter. So this check can actually be an assert. Differential Revision: [D45300396](https://our.internmc.facebook.com/intern/diff/D45300396/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/100046 Approved by: https://github.com/wanchaol
Author
Committer
Parents
Loading