Remove sync for randperm on small tensors. (#54113)
Summary:
For small tensors, the GPU is known to be slower than the CPU, so `randperm` used to offload small inputs to the CPU. However, offloading to the CPU causes a host <--> device sync. As a result, although the CPU path wins in microbenchmarks, it often hurts rather than helps end-to-end performance, and the sync can be a blocker for CUDA graphs. After discussion with mcarilli and ptrblck, we think it is better to simply remove this piece of code and accept the slower GPU kernel for small sizes.
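For context, the permutation itself is cheap to generate on the host. A minimal Python sketch of the classic Fisher-Yates shuffle (illustrative only; not claimed to be PyTorch's exact CPU kernel):

```python
import random

def randperm_cpu(n):
    # Build the identity permutation, then shuffle in place
    # with Fisher-Yates: each of the n! permutations is
    # equally likely.
    perm = list(range(n))
    for i in range(n - 1, 0, -1):
        j = random.randint(0, i)  # inclusive on both ends
        perm[i], perm[j] = perm[j], perm[i]
    return perm
```

The expensive part on the old code path was not this computation but the device copy and synchronization that followed it.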
Microbenchmarks:
```python
def run50_sync(f):
    for _ in range(50):
        f()
        torch.cuda.synchronize()
    torch.cuda.synchronize()
%timeit run50_sync(lambda: torch.randperm(3, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(30, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(300, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(3000, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(30000, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(300000, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(3000000, device='cuda'))
%timeit run50_sync(lambda: torch.randperm(30000000, device='cuda'))
```
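Outside IPython, `%timeit` is unavailable; an equivalent plain-Python harness (the `time50` name is mine, and the `synchronize` callback stands in for `torch.cuda.synchronize` so the sketch also runs without a GPU):

```python
import time

def run50_sync(f, synchronize=lambda: None):
    # Call f 50 times, synchronizing after each call and
    # once more at the end, mirroring the benchmark above.
    for _ in range(50):
        f()
        synchronize()
    synchronize()

def time50(f, synchronize=lambda: None):
    # Wall-clock time for one run50_sync pass, in seconds.
    start = time.perf_counter()
    run50_sync(f, synchronize)
    return time.perf_counter() - start
```

On a CUDA machine one would pass `synchronize=torch.cuda.synchronize` and `f=lambda: torch.randperm(n, device='cuda')`.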
Before this PR:
```
5.79 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.78 ms ± 92.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.17 ms ± 87.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.65 ms ± 69.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
17.6 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
21 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
104 ms ± 880 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
944 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
After this PR:
```
7.22 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.28 ms ± 9.03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.25 ms ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.19 ms ± 5.83 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
9.76 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.3 ms ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
69.3 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
716 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
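The per-size change from the two tables can be summarized as a ratio (values above 1 mean this PR is faster; times in ms, copied from above):

```python
# Mean times from the tables, smallest size (n=3) to largest
# (n=30000000), in milliseconds.
before = [5.79, 5.78, 6.17, 9.65, 17.6, 21.0, 104.0, 944.0]
after = [7.22, 7.28, 7.25, 9.19, 9.76, 12.3, 69.3, 716.0]

# before/after speedup: < 1 is a regression, > 1 an improvement.
ratios = [b / a for b, a in zip(before, after)]
```

The smallest sizes regress modestly (roughly 0.8x), while mid-to-large sizes improve, consistent with trading the CPU offload's sync overhead for a slower but async GPU kernel.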
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54113
Reviewed By: ezyang
Differential Revision: D28017958
Pulled By: ngimel
fbshipit-source-id: 660992d43ca449e61ce0cb0aa1dae554c9560a4e