Adds 3D attn_mask support to merge_masks() for Multihead Attention fast path (#98991)
Fixes #97409
Adds support for 3D attn_mask by always expanding attn_mask to 4D as per https://github.com/pytorch/pytorch/pull/98375#issuecomment-1499504721
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98991
Approved by: https://github.com/jbschlosser