7443c90f - optimize non lastdim softmax bf16 (#60371)

optimize non lastdim softmax bf16 (#60371)

Summary: Here is the PR to enable the softmax calculation with the `bfloat16` data type when not along the last dim.

* Use the bf16 specialization for the forward calculation to reduce the bf16/fp32 casts in the vec template.
* Release the bf16 limitation for the backward calculation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60371
Reviewed By: ejguan
Differential Revision: D29563109
Pulled By: cpuhrsch
fbshipit-source-id: f6b439fa3850a6c633f35db65ea3d735b747863e
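The cast-reduction idea behind the commit can be sketched outside of ATen: upcast bf16 inputs to fp32 once, do the whole reduction (max, exp, sum, divide) in fp32, and downcast only the final result, rather than round-tripping through bf16 at every step. This is a minimal pure-Python sketch, not the actual PyTorch kernel; `to_bf16` and `softmax_dim0` are hypothetical helper names, and bfloat16 is emulated by truncating the float32 bit pattern.

```python
import math
import struct

def to_bf16(x: float) -> float:
    # Emulate a bfloat16 round-trip: keep only the top 16 bits of the
    # float32 encoding (sign, 8 exponent bits, 7 mantissa bits).
    # Hypothetical helper, not a PyTorch API.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def softmax_dim0(rows):
    # Numerically stable softmax down the columns (dim=0, i.e. NOT the
    # last dim) of a 2-D list of emulated-bf16 values. Each column is
    # upcast to fp32 once, fully reduced in fp32, and only the final
    # probabilities are downcast back to bf16 -- mirroring the
    # reduced-cast strategy the commit describes.
    out_cols = []
    for col in zip(*rows):
        col32 = [to_bf16(v) for v in col]          # inputs assumed bf16
        m = max(col32)                             # stability shift
        exps = [math.exp(v - m) for v in col32]    # fp32 intermediate
        s = sum(exps)
        out_cols.append([to_bf16(e / s) for e in exps])
    return [list(r) for r in zip(*out_cols)]       # transpose back
```

With bf16's 7-bit mantissa, each column still sums to 1 only to within roughly 1%, which is why the downcast is deferred to the very end.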