don't hardcode mask type in mha (#68077)
Summary:
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68077
Reviewed By: zou3519
Differential Revision: D32292410
Pulled By: ngimel
fbshipit-source-id: 67213cf5474dc3f83e90e28cf5a823abb683a6f9