[fix] cauchy sampling inf on cuda (#60186)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/59144
As pointed out by ngimel, the issue is indeed with calling `tan`.
However, the C++ `std::tan` [documentation](https://en.cppreference.com/w/cpp/numeric/math/tan) states that
```
The function has mathematical poles at π(1/2 + n); however no common floating-point representation
is able to represent π/2 exactly, thus there is no value of the argument for which a pole error occurs.
```
`torch.tan`, `numpy.tan`, and `math.tan` all comply with the above statement.
<details>
```python
import torch
import math
import numpy as np
# Single Precision
print(torch.tan(torch.tensor(math.pi, device='cuda', dtype=torch.float32) * 0.5))
print(np.tan(np.array(np.pi, dtype=np.float32) * 0.5))
# Double Precision
print(math.tan(math.pi * 0.5))
print(torch.tan(torch.tensor(math.pi, device='cuda', dtype=torch.double) * 0.5))
print(np.tan(np.array(np.pi, dtype=np.float64) * 0.5))
```
Output
```
tensor(-22877334., device='cuda:0')
-22877332.42885646
1.633123935319537e+16
tensor(1.6331e+16, device='cuda:0', dtype=torch.float64)
1.633123935319537e+16
```
</details>
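For context, `Tensor.cauchy_` draws samples through the inverse CDF of the Cauchy distribution, which is where the `tan` call comes from. A minimal NumPy sketch of that transform (the function and variable names here are illustrative, not PyTorch's):

```python
import numpy as np

# Inverse CDF of the Cauchy distribution: x = median + sigma * tan(pi * (u - 0.5)).
# This is the transform a cauchy_ sampler applies to uniform samples u in [0, 1).
def cauchy_inverse_cdf(u, median=0.0, sigma=1.0):
    return median + sigma * np.tan(np.pi * (u - 0.5))

rng = np.random.default_rng(0)
u = rng.random(100_000, dtype=np.float32)
samples = cauchy_inverse_cdf(u)

# Because no float represents pi/2 exactly, a correctly rounded tan never
# hits the pole, so every sample is finite even in single precision.
print(np.isfinite(samples).all())  # True
```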
So the issue stems from the use of `__tanf`, CUDA's faster but less accurate intrinsic approximation of `tan`, for float16, bfloat16 and float32:
https://github.com/pytorch/pytorch/blob/8a839c54788e6551ead9a018993c4995e02f3219/aten/src/ATen/NumericUtils.h#L91-L100
The fix in this PR is to use the **slower** but more accurate version.
Benchmark:
```
[ cauchy : input dtype torch.float16 device cuda ]
| Before | After
1 threads: -------------------------------------
(128,) | 3.8 | 4.3
(256, 128) | 3.8 | 4.2
(2, 512, 256) | 3.8 | 4.2
(2, 64, 256, 128) | 22.8 | 29.6
(4, 2, 512, 256, 128) | 649.6 | 869.3
Times are in microseconds (us).
[ cauchy : input dtype torch.bfloat16 device cuda ]
| Before | After
1 threads: -------------------------------------
(128,) | 3.8 | 4.3
(256, 128) | 3.8 | 4.3
(2, 512, 256) | 3.8 | 4.3
(2, 64, 256, 128) | 23.8 | 30.8
(4, 2, 512, 256, 128) | 682.5 | 904.2
Times are in microseconds (us).
[ cauchy : input dtype torch.float32 device cuda ]
| Before | After
1 threads: --------------------------------------
(128,) | 3.8 | 4.2
(256, 128) | 3.7 | 4.2
(2, 512, 256) | 3.7 | 4.2
(2, 64, 256, 128) | 35.3 | 37.1
(4, 2, 512, 256, 128) | 1020.0 | 1058.3
Times are in microseconds (us).
[ cauchy : input dtype torch.float64 device cuda ]
| Before | After
1 threads: ----------------------------------------
(128,) | 3.8 | 4.2
(256, 128) | 8.0 | 8.0
(2, 512, 256) | 46.0 | 46.0
(2, 64, 256, 128) | 669.2 | 669.4
(4, 2, 512, 256, 128) | 21255.0 | 21262.1
Times are in microseconds (us).
```
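For scale, the relative overhead of the accurate `tan` on the largest shape can be read off the tables above (timings copied verbatim; the float64 path is untouched by this PR, so it serves as a sanity check):

```python
# Largest-shape (4, 2, 512, 256, 128) timings in microseconds, before vs. after.
timings = {
    "float16":  (649.6, 869.3),
    "bfloat16": (682.5, 904.2),
    "float32":  (1020.0, 1058.3),
    "float64":  (21255.0, 21262.1),  # double path unchanged, overhead ~0%
}
for dtype, (before, after) in timings.items():
    print(f"{dtype}: +{(after - before) / before:.1%}")
```

The slowdown is roughly a third for the half-precision types and under 4% for float32.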
<details>
Benchmark Script:
```python
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle
print('Using pytorch %s' % (torch.__version__))
cuda_shapes = [(128,), (256, 128), (2, 512, 256), (2, 64, 256, 128), (4, 2, 512, 256, 128)]
cuda_dtypes = [torch.half, torch.bfloat16, torch.float, torch.double]
results = []
repeats = 10
for device in ['cuda']:
    dtypes = cuda_dtypes
    shapes = cuda_shapes
    for dtype in dtypes:
        for shape in shapes:
            t = torch.randn(shape, device=device, dtype=dtype) * 10
            tasks = [("t.cauchy_()", "After", "")]
            timers = [Timer(stmt=stmt, label=f"cauchy : input dtype {dtype} device {device}", sub_label=f"{shape}", description=desc, globals=globals()) for stmt, desc, label in tasks]
            for i, timer in enumerate(timers * repeats):
                results.append(
                    timer.blocked_autorange()
                )
                print(f"\r{i + 1} / {len(timers) * repeats}", end="")
                sys.stdout.flush()

with open('after-pr.pkl', 'wb') as f:
    pickle.dump(results, f)

comparison = Compare(results)
comparison.print()
```
Compare Script:
```python
import torch
import itertools
import time
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare
import sys
import pickle
with open('before-pr.pkl', 'rb') as f:
    before_results = pickle.load(f)
with open('after-pr.pkl', 'rb') as f:
    after_results = pickle.load(f)
comparison = Compare(before_results + after_results)
comparison.print()
```
</details>
TODO:
* [x] Add comment
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60186
Reviewed By: jbschlosser
Differential Revision: D29433897
Pulled By: ngimel
fbshipit-source-id: 9c5f14b83e3372bed72369f70eed9256c04385c6