Filter 0's returned by exponential distribution (#53480)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48841 for the half datatype (it had previously been fixed for the other datatypes).
The reason https://github.com/pytorch/pytorch/issues/48841 happened for half was that `exponential_` for half was producing 0s.
The exponential distribution implementation on CUDA is here: https://github.com/pytorch/pytorch/blob/e08aae261397b8da3e71024bbeddfe0487185d1d/aten/src/ATen/native/cuda/DistributionTemplates.h#L535-L545
with `transformation::exponential` defined here:
https://github.com/pytorch/pytorch/blob/e08aae261397b8da3e71024bbeddfe0487185d1d/aten/src/ATen/core/TransformationHelper.h#L113-L123
It takes a uniformly distributed random number and applies `log` to it; if necessary, the result is then converted to the low-precision datatype (half). To avoid 0's, ones are replaced with `std::nextafter(1, 0)` before `log` is applied. This seems fine, because log(1-eps) is still representable in half precision (`torch.tensor([1.], device="cuda").nextafter(torch.tensor([0.], device="cuda")).log().half()` produces 5.96e-8), so casting to `scalar_t` should work. However, since the fast log approximation (`__logf`) is used, the log result is ~3e-9 instead of the more accurate 5.96e-8, and it underflows when cast to half. Using `::log` instead of the fast approximation fixes this, but it comes with a ~20% perf penalty on the exponential kernel for the fp32 datatype, and probably more for half.
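The magnitudes involved can be checked without the CUDA kernel; the following is a minimal sketch (CPU tensors, not the actual `__logf` path, and the ~3e-9 value is only an assumed stand-in for what the fast approximation returns):

```python
import torch

# Largest float32 value strictly below 1.0, i.e. nextafter(1, 0) = 1 - 2**-24.
u_max = torch.nextafter(torch.tensor([1.0]), torch.tensor([0.0]))

# Accurate log: log(1 - 2**-24) ~= -2**-24 ~= -5.96e-8.
accurate = u_max.log()
print(accurate)          # ~ -5.9605e-08
print(accurate.half())   # still nonzero: 2**-24 is the smallest half subnormal

# A value of the order the fast approximation produces (~3e-9) is below the
# smallest half-precision subnormal, so it flushes to zero when cast to half.
print(torch.tensor([-3e-9]).half())   # -0. in float16
```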
Edit: the alternative approach used now is to filter out all small values returned by the transformation. The result is equivalent to the previous squashing of 1's to 1-eps combined with computing the correct log of 1-eps (which is -eps, exactly equal even for doubles). This doesn't incur a noticeable performance hit.
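A rough sketch of the filtering idea in Python (not the actual ATen/CUDA code; the function name, the threshold, and the use of `clamp` are illustrative assumptions):

```python
import torch

def exponential_from_uniform(u, lambd=1.0, out_dtype=torch.float16):
    # u is uniform in (0, 1]; the exponential sample is -log(u) / lambd.
    eps = torch.finfo(torch.float32).eps
    log_u = torch.log(u)
    # Filter out log results too close to zero: anything above -eps (which can
    # only come from u ~= 1, where log(1 - eps) == -eps) is replaced by -eps,
    # so the sample never underflows to 0 when cast to half.
    log_u = log_u.clamp(max=-eps)
    return (-log_u / lambd).to(out_dtype)

# Even u == 1.0 now yields a small positive sample instead of 0.
print(exponential_from_uniform(torch.tensor([1.0, 0.5])))
```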
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53480
Reviewed By: mruberry
Differential Revision: D26924622
Pulled By: ngimel
fbshipit-source-id: dc1329e4773bf91f26af23c8afa0ae845cfb0937
Author: Natalia Gimelshein