Enable log_softmax and CrossEntropyLoss for bfloat16 (#24457)
Summary:
Enabled torch.nn.functional.log_softmax and torch.nn.CrossEntropyLoss for the bfloat16 data type.
To do that, the following dependencies had to be enabled:
- RNE (round to nearest even)
- AccumulateType
- bfloat16 arithmetic operator overloads (see the sketch after this list)
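The operator overloads can be illustrated with a minimal sketch, using a hypothetical standalone `bf16` struct rather than PyTorch's actual c10::BFloat16: each operator widens its operands to float, computes in float, and narrows the result back to bfloat16.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical bfloat16 wrapper for illustration only; PyTorch's
// real type is c10::BFloat16 in c10/util/BFloat16.h.
struct bf16 { uint16_t bits; };

// bfloat16 is the high 16 bits of an IEEE-754 float32, so widening
// is a shift plus a bitwise reinterpretation.
inline float to_float(bf16 h) {
  uint32_t u = static_cast<uint32_t>(h.bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Narrowing shown here as plain truncation for brevity; the RNE
// rounding this PR enables is sketched under "RNE vs truncate" below.
inline bf16 from_float(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return bf16{static_cast<uint16_t>(u >> 16)};
}

// Each overload computes in float and narrows once, so a single
// operation loses no more precision than the final conversion.
inline bf16 operator+(bf16 a, bf16 b) { return from_float(to_float(a) + to_float(b)); }
inline bf16 operator*(bf16 a, bf16 b) { return from_float(to_float(a) * to_float(b)); }
```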
We also implemented full std::numeric_limits support for the bfloat16 data type.
Background for the dependencies:
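Continuing the bf16 sketch above, here is a trimmed illustration of what a full std::numeric_limits specialization involves (PyTorch specializes it for c10::BFloat16; the member values below follow from the bfloat16 format itself: 8 exponent bits, 7 stored mantissa bits):

```cpp
#include <limits>

namespace std {
template <>
class numeric_limits<bf16> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr bool is_signed = true;
  static constexpr bool is_integer = false;
  static constexpr bool is_exact = false;
  static constexpr bool has_infinity = true;
  static constexpr bool has_quiet_NaN = true;
  // 7 stored mantissa bits plus the implicit leading 1.
  static constexpr int digits = 8;
  // Same exponent range as float32, since the exponent field is identical.
  static constexpr int min_exponent = -125;
  static constexpr int max_exponent = 128;
  // A full specialization also provides min(), max(), lowest(),
  // epsilon(), infinity(), quiet_NaN(), etc.; elided for brevity.
};
}  // namespace std
```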
- RNE vs truncate
From a torch.nn.CrossEntropyLoss test with input_size=(128, 1000):
RNE result:
float output: tensor(7.3981, dtype=torch.float32, grad_fn=<NllLossBackward>)
bfloat16 output: tensor(7.3125, dtype=torch.bfloat16, grad_fn=<NllLossBackward>)
truncate result:
float output: tensor(7.3981, dtype=torch.float32, grad_fn=<NllLossBackward>)
bfloat16 output: tensor(5.8750, dtype=torch.bfloat16, grad_fn=<NllLossBackward>)
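At the bit level the two conversion modes differ only by a rounding bias, but the test above shows how decisive that is. A minimal sketch of both (this is the standard bfloat16 rounding trick, not code copied from this PR's diff):

```cpp
#include <cstdint>
#include <cstring>

// Truncate: drop the low 16 bits of the float32 encoding. The error
// is always toward zero in magnitude, so it biases results and
// compounds across a computation such as a loss reduction.
uint16_t to_bf16_truncate(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}

// RNE: add a bias of 0x7FFF, plus 1 when the lowest kept bit is set,
// so values round to the nearest bfloat16 and exact ties go to the
// neighbor with an even low bit. Errors stay unbiased on average.
// (NaN inputs need a separate check in production code, since the
// bias can carry into the exponent field.)
uint16_t to_bf16_rne(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<uint16_t>((u + rounding_bias) >> 16);
}
```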
- scalar_t vs AccumulateType (the AccumulateType of bfloat16 is float)
AccumulateType is essential for preserving accuracy, especially in reduction-related operations. We verified this with both local test cases and a real topology: a bfloat16 accumulator can cause a huge relative error when the number of elements is large, in some cases more than 50%.
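The failure mode is easy to reproduce with the bf16 sketch above (hypothetical names; PyTorch's real trait is at::AccumulateType / at::acc_type in ATen/AccumulateType.h): once a bfloat16 running sum reaches 2^8, adding a unit-sized element is at most half a ULP of the sum and is rounded away entirely.

```cpp
#include <cstdio>

// Hypothetical trait mapping a scalar type to its accumulator type;
// the identity by default, float for bfloat16.
template <typename T> struct AccType { using type = T; };
template <> struct AccType<bf16> { using type = float; };

int main() {
  const int n = 100000;
  bf16 one = from_float(1.0f);

  // Sum n ones with a bfloat16 accumulator, rounding after every add.
  bf16 bf_sum = from_float(0.0f);
  // Sum the same values with the AccumulateType (float) accumulator.
  AccType<bf16>::type f_sum = 0.0f;

  for (int i = 0; i < n; ++i) {
    bf_sum = bf_sum + one;   // stalls at 256: 256 + 1 rounds back to 256
    f_sum += to_float(one);  // float holds integers exactly up to 2^24
  }
  std::printf("bf16  accumulator: %.1f\n", to_float(bf_sum));  // 256.0
  std::printf("float accumulator: %.1f\n", f_sum);             // 100000.0
  // Relative error of the bf16 accumulator here: ~99.7%.
}
```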
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24457
Differential Revision: D17113018
Pulled By: ezyang
fbshipit-source-id: 8d61297ca118f9b5c6730a01efcf3a3704d2f206