pytorch
105e58a5 - [Foreach Reduction] Use `OpMathType` tensor for intermediate results

[Foreach Reduction] Use `OpMathType` tensor for intermediate results

Follow-up of https://github.com/pytorch/pytorch/pull/62646.

In APEX, multi_tensor_norm supports only float and half, and the dtype of `output` and `output_per_tensor` is hardcoded to single precision (see https://github.com/NVIDIA/apex/blob/ae757634efa26a4ed852324f1d32f2159774997b/csrc/multi_tensor_l2norm_kernel.cu#L318). In my previous PR, however, every tensor created in the kernel had the same dtype as the input tensors. I'm not sure why I didn't see any failures in that PR, but in my opinion the internal math for 16-bit tensors should be performed in 32 bits.

rel: https://github.com/pytorch/pytorch/issues/58833

cc @ptrblck @mcarilli @ngimel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68107
Approved by: https://github.com/ngimel
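To illustrate why 32-bit intermediates matter for 16-bit inputs, here is a minimal NumPy sketch (not the actual kernel code) of an L2-norm reduction. Accumulating the sum of squares in `float16` stalls once the running total grows large enough that each small addend falls below half a ulp, while accumulating in `float32` (the role `OpMathType` plays for half) stays exact for this input:

```python
import numpy as np

# 4096 half-precision values of 0.5; the true sum of squares is 4096 * 0.25 = 1024.
vals = np.full(4096, 0.5, dtype=np.float16)

# Accumulate in the input dtype (float16): precision is lost as the total grows.
acc16 = np.float16(0.0)
for v in vals:
    acc16 = np.float16(acc16 + v * v)

# Accumulate in float32 and cast only at the end: exact here.
acc32 = np.float32(0.0)
for v in vals:
    acc32 = np.float32(acc32 + np.float32(v) * np.float32(v))

norm16 = np.sqrt(acc16)  # noticeably smaller than the true norm (32.0)
norm32 = np.sqrt(acc32)  # matches the true norm
```

The float16 accumulator stops increasing once each `0.25` addend rounds away against the large running sum, so the final norm is visibly wrong; widening only the intermediates fixes this without changing the input or output dtypes.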