Speedup bincount and histc on CUDA (#97090)
This is to speed up torch.bincount and torch.histc on CUDA.
1. Speed up int64_t gpuAtomicAdd,
2. and optimize the histogram kernel.
# Fixes #96626
After speedup, time cost in #96626 would be
```
... (run 2 times and ignore the first run)
case 1 CPU 0.0003631114959716797 seconds
case 1 CUDA 0.0005860328674316406 seconds
case 2 CPU 0.0013742446899414062 seconds
case 2 CUDA 0.0008623600006103516 seconds
```
Note that in "*case 1 CUDA*", the **max** op takes the most time, i.e., https://github.com/pytorch/pytorch/blob/5ee5a164ffeb7b7a167c53009fb8fe5f5bd439d9/aten/src/ATen/native/cuda/SummaryOps.cu#L334-L335, which is not to be optimized in this PR.
# Benchmark
Time is measured on i7-10700 + RTX 3080, Ubuntu 22.04 (in WSL). The baseline is PyTorch 2.0.0+cu117. My dev version of PyTorch is compiled with CUDA 11.8. Each case is measured 15 times to take the median.
## torch.bincount
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | 0.000834 | 0.005783 | 0.000266 | 21.8x
2**20 | 80 | narrow in 1 bin | 0.001576 | 0.003967 | 0.000563 | 7.0x
2**20 | 500 | random.uniform | 0.000852 | 0.003641 | 0.000334 | 10.9x
2**20 | 500 | narrow in 1% bins | 0.000894 | 0.001878 | 0.000349 | 5.4x
2**20 | 2048 | random.uniform | 0.000891 | 0.000820 | 0.000298 | 2.8x
2**20 | 2048 | narrow in 1% bins | 0.000958 | 1.043251 | 0.000335 | 3,116.6x
2**26 | 80 | random.uniform | 0.067715 | 0.322409 | 0.003032 | 106.3x
2**26 | 80 | narrow in 1 bin | 0.110940 | 0.194644 | 0.017651 | 11.0x
2**26 | 500 | random.uniform | 0.066666 | 0.192302 | 0.002535 | 75.8x
2**26 | 500 | narrow in 1% bins | 0.066130 | 0.092237 | 0.005462 | 16.9x
2**26 | 2048 | random.uniform | 0.066371 | 0.035308 | 0.002476 | 14.3x
2**26 | 2048 | narrow in 1% bins | 0.068453 | 72.122858 | 0.003185 | 22,644.3x
## torch.histc (float32)
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | 0.001261 | 0.000145 | 9.47E-05 | 1.5x
2**20 | 80 | narrow in 1 bin | 0.001074 | 0.000356 | 0.000311 | 1.1x
2**20 | 500 | random.uniform | 0.001162 | 0.000227 | 9.18E-05 | 2.5x
2**20 | 500 | narrow in 1% bins | 0.001082 | 0.000201 | 0.000152 | 1.3x
2**20 | 2048 | random.uniform | 0.001100 | 0.000203 | 0.000118 | 1.7x
2**20 | 2048 | narrow in 1% bins | 0.001089 | 0.000396 | 0.000107 | 3.7x
2**26 | 80 | random.uniform | 0.064219 | 0.001170 | 0.000786 | 1.5x
2**26 | 80 | narrow in 1 bin | 0.056471 | 0.013283 | 0.011939 | 1.1x
2**26 | 500 | random.uniform | 0.078183 | 0.003411 | 0.000562 | 6.1x
2**26 | 500 | narrow in 1% bins | 0.056711 | 0.002763 | 0.002738 | 1.0x
2**26 | 2048 | random.uniform | 0.059296 | 0.003503 | 0.000533 | 6.6x
2**26 | 2048 | narrow in 1% bins | 0.061754 | 0.015703 | 0.000962 | 16.3x
## torch.histc (int64)
#elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup
-- | -- | -- | -- | -- | -- | --
2**20 | 80 | random.uniform | N/A | 0.005614 | 9.47E-05 | 59.3x
2**20 | 80 | narrow in 1 bin | N/A | 0.003799 | 0.000395 | 9.6x
2**20 | 500 | random.uniform | N/A | 0.003665 | 9.58E-05 | 38.2x
2**20 | 500 | narrow in 1% bins | N/A | 0.001760 | 0.000178 | 9.9x
2**20 | 2048 | random.uniform | N/A | 0.000693 | 0.000111 | 6.2x
2**20 | 2048 | narrow in 1% bins | N/A | 1.082904 | 0.000123 | 8,802.4x
2**26 | 80 | random.uniform | N/A | 0.320400 | 0.001145 | 279.9x
2**26 | 80 | narrow in 1 bin | N/A | 0.193668 | 0.015229 | 12.7x
2**26 | 500 | random.uniform | N/A | 0.182897 | 0.000823 | 222.2x
2**26 | 500 | narrow in 1% bins | N/A | 0.089363 | 0.00376 | 23.8x
2**26 | 2048 | random.uniform | N/A | 0.033190 | 0.000832 | 39.9x
2**26 | 2048 | narrow in 1% bins | N/A | 71.721012 | 0.001525 | 47,017.8x
## Banchmark code
Here is the benchmark code:
```python3
import time
import torch
cases = [
("bincount bins=80 wide ", torch.randint(80, [2**20]), lambda x: torch.bincount(x, minlength=80)),
("bincount bins=80 narrow", torch.randint(1, [2**20]), lambda x: torch.bincount(x, minlength=80)),
("bincount bins=500 wide ", torch.randint(500, [2**20]), lambda x: torch.bincount(x, minlength=500)),
("bincount bins=500 narrow", torch.randint(5, [2**20]), lambda x: torch.bincount(x, minlength=500)),
("bincount bins=2048 wide ", torch.randint(2048, [2**20]), lambda x: torch.bincount(x, minlength=2048)),
("bincount bins=2048 narrow", torch.randint(20, [2**20]), lambda x: torch.bincount(x, minlength=2048)),
("histc_float bins=80 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=80, min=0., max=1.)),
("histc_float bins=80 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=80, min=0., max=1.)),
("histc_float bins=500 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=500, min=0., max=1.)),
("histc_float bins=500 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=500, min=0., max=1.)),
("histc_float bins=2048 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=2048, min=0., max=1.)),
("histc_float bins=2048 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=2048, min=0., max=1.)),
("histc_int bins=80 wide ", torch.randint(80, [2**20]), lambda x: torch.histc(x, bins=80, min=0., max=80.)),
("histc_int bins=80 narrow", torch.randint(1, [2**20]), lambda x: torch.histc(x, bins=80, min=0., max=80.)),
("histc_int bins=500 wide ", torch.randint(500, [2**20]), lambda x: torch.histc(x, bins=500, min=0., max=500.)),
("histc_int bins=500 narrow", torch.randint(5, [2**20]), lambda x: torch.histc(x, bins=500, min=0., max=500.)),
("histc_int bins=2048 wide ", torch.randint(2048, [2**20]), lambda x: torch.histc(x, bins=2048, min=0., max=2048.)),
("histc_int bins=2048 narrow", torch.randint(20, [2**20]), lambda x: torch.histc(x, bins=2048, min=0., max=2048.)),
]
def test(case, device):
name, x, func = case
x = x.to(device)
time_samples = []
for _ in range(15):
torch.cuda.synchronize()
t1 = time.time()
func(x)
torch.cuda.synchronize()
t2 = time.time()
time_samples.append(t2 - t1)
median = sorted(time_samples)[len(time_samples) // 2]
print(device, name, median)
for case in cases:
test(case, device="cuda")
# for case in cases:
# test(case, device="cpu")
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97090
Approved by: https://github.com/ngimel