optimize non lastdim softmax bf16 (#60371)
Summary:
Here is the PR to enable the softmax calculation with data type of `bfloat16` when not along the last dim.
* Use bf16 specialization for forward calculation to reduce the bf16/fp32 cast in vec template.
* Release the bf16 limitation for backward calculation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60371
Reviewed By: ejguan
Differential Revision: D29563109
Pulled By: cpuhrsch
fbshipit-source-id: f6b439fa3850a6c633f35db65ea3d735b747863e