71dddec6 - Cast grad_input to half when input_dtype is half in _softmax_backward_data aten decomposition (#85497)

Fixes #85504

The native `_softmax_backward_data` and `_log_softmax_backward_data` kernels cast `grad_input` to half when `input_dtype` is half. Without the same cast in the aten decomposition, running under amp leaves `grad_input` in float32, and consumer ops then trigger `RuntimeError: expected scalar type Float but found Half`.

Native implementations for reference:
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L70-L83
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SoftMax.cpp#L102-L113

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85497
Approved by: https://github.com/ngimel
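A minimal sketch of what the patched decomposition looks like. The standalone function name `softmax_backward_data` is illustrative (the real decomposition is registered against `aten._softmax_backward_data`); the key change is the final dtype cast back to `input_dtype`:

```python
import torch

def softmax_backward_data(grad_output, output, dim, input_dtype):
    # softmax backward: grad_input_i = y_i * (g_i - sum_j g_j * y_j)
    new_grad = grad_output * output
    grad_input = new_grad - output * new_grad.sum(dim=dim, keepdim=True)
    # The fix: under amp, grad_output/output may be float32 even though the
    # forward input was half. Cast back so consumer ops see the expected
    # half dtype instead of raising
    # "RuntimeError: expected scalar type Float but found Half".
    if grad_output.dtype != input_dtype:
        grad_input = grad_input.to(input_dtype)
    return grad_input
```

With float32 gradients and `input_dtype=torch.half`, the returned tensor is half, matching the eager-mode kernels linked above.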