properly compute batch_element_count (#82927)
Turns out local_batches can sometimes be completely bogus (I thought that for masked softmax it was guaranteed to be equal to WARP_BATCH), so it has to be taken into account when computing the real number of elements.
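A rough host-side model of the bookkeeping, for illustration only. It assumes local_batches is derived as batch_size - first_batch and capped at WARP_BATCH but not clamped at zero; the names first_batch, batch_size, element_count and the helper functions are assumptions made for this sketch, not code from this change.

```python
# Host-side model of the per-warp batch bookkeeping; NOT the actual kernel code.
WARP_BATCH = 2  # assumed value, chosen only for illustration


def local_batches_for(first_batch: int, batch_size: int) -> int:
    """Rows this warp owns: capped at WARP_BATCH, but (by assumption) not clamped at 0."""
    return min(batch_size - first_batch, WARP_BATCH)


def batch_element_counts(first_batch: int, batch_size: int, element_count: int) -> list:
    """Element count per warp slot; slots at or past local_batches hold no real data."""
    local_batches = local_batches_for(first_batch, batch_size)
    return [0 if i >= local_batches else element_count for i in range(WARP_BATCH)]


if __name__ == "__main__":
    # With 3 rows of 5 elements, successive warps start at rows 0, 2, 4.
    for first_batch in (0, 2, 4):
        print(first_batch, batch_element_counts(first_batch, batch_size=3, element_count=5))
```

In this model a warp that starts past the end of the batch gets a negative local_batches, so every slot has to fall back to a count of 0 instead of element_count.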
cc @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82927
Approved by: https://github.com/erichan1