[Tensor Parallel] Simplify distribute for MHA (#100046)
This function is only called for nn.MultiheadAttention or our custom MHA
module, and the former is always converted to the latter first. The type
check can therefore be an assert.
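A minimal sketch of the pattern, not the actual PyTorch code: `StockMHA`, `CustomMHA`, `to_custom`, and `distribute_mha` are hypothetical stand-ins for nn.MHA, the custom MHA, the conversion step, and the distribute function. It assumes the conversion runs before the type check.

```python
class StockMHA:
    """Hypothetical stand-in for nn.MHA."""

    def __init__(self, embed_dim, num_heads):
        self.embed_dim = embed_dim
        self.num_heads = num_heads


class CustomMHA:
    """Hypothetical stand-in for the custom MHA module."""

    def __init__(self, embed_dim, num_heads):
        self.embed_dim = embed_dim
        self.num_heads = num_heads


def to_custom(module):
    # Convert the stock module to the custom one; leave anything else as-is.
    if isinstance(module, StockMHA):
        return CustomMHA(module.embed_dim, module.num_heads)
    return module


def distribute_mha(module):
    module = to_custom(module)
    # Callers only pass one of the two MHA types, so after conversion any
    # other type is a programming error rather than a case to handle.
    assert isinstance(module, CustomMHA), (
        f"expected CustomMHA, got {type(module).__name__}"
    )
    # The real distribution logic would go here.
    return module
```

With this shape, a caller that passes an unexpected module type fails loudly at the assert instead of falling through a silent branch.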
Differential Revision: [D45300396](https://our.internmc.facebook.com/intern/diff/D45300396/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100046
Approved by: https://github.com/wanchaol