pytorch
63e1f12b - Speedup bincount and histc on CUDA (#97090)

Commit
1 year ago
Speedup bincount and histc on CUDA (#97090) This is to speed up torch.bincount and torch.histc on CUDA. 1. Speed up int64_t gpuAtomicAdd, 2. and optimize the histogram kernel. # Fixes #96626 After speedup, time cost in #96626 would be ``` ... (run 2 times and ignore the first run) case 1 CPU 0.0003631114959716797 seconds case 1 CUDA 0.0005860328674316406 seconds case 2 CPU 0.0013742446899414062 seconds case 2 CUDA 0.0008623600006103516 seconds ``` Note that in "*case 1 CUDA*", the **max** op takes the most time, i.e., https://github.com/pytorch/pytorch/blob/5ee5a164ffeb7b7a167c53009fb8fe5f5bd439d9/aten/src/ATen/native/cuda/SummaryOps.cu#L334-L335, which is not to be optimized in this PR. # Benchmark Time is measured on i7-10700 + RTX 3080, Ubuntu 22.04 (in WSL). The baseline is PyTorch 2.0.0+cu117. My dev version of PyTorch is compiled with CUDA 11.8. Each case is measured 15 times to take the median. ## torch.bincount #elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup -- | -- | -- | -- | -- | -- | -- 2**20 | 80 | random.uniform | 0.000834 | 0.005783 | 0.000266 | 21.8x 2**20 | 80 | narrow in 1 bin | 0.001576 | 0.003967 | 0.000563 | 7.0x 2**20 | 500 | random.uniform | 0.000852 | 0.003641 | 0.000334 | 10.9x 2**20 | 500 | narrow in 1% bins | 0.000894 | 0.001878 | 0.000349 | 5.4x 2**20 | 2048 | random.uniform | 0.000891 | 0.000820 | 0.000298 | 2.8x 2**20 | 2048 | narrow in 1% bins | 0.000958 | 1.043251 | 0.000335 | 3,116.6x 2**26 | 80 | random.uniform | 0.067715 | 0.322409 | 0.003032 | 106.3x 2**26 | 80 | narrow in 1 bin | 0.110940 | 0.194644 | 0.017651 | 11.0x 2**26 | 500 | random.uniform | 0.066666 | 0.192302 | 0.002535 | 75.8x 2**26 | 500 | narrow in 1% bins | 0.066130 | 0.092237 | 0.005462 | 16.9x 2**26 | 2048 | random.uniform | 0.066371 | 0.035308 | 0.002476 | 14.3x 2**26 | 2048 | narrow in 1% bins | 0.068453 | 72.122858 | 0.003185 | 22,644.3x ## torch.histc (float32) #elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup -- | -- | -- | -- | -- | -- | -- 2**20 | 80 | random.uniform | 0.001261 | 0.000145 | 9.47E-05 | 1.5x 2**20 | 80 | narrow in 1 bin | 0.001074 | 0.000356 | 0.000311 | 1.1x 2**20 | 500 | random.uniform | 0.001162 | 0.000227 | 9.18E-05 | 2.5x 2**20 | 500 | narrow in 1% bins | 0.001082 | 0.000201 | 0.000152 | 1.3x 2**20 | 2048 | random.uniform | 0.001100 | 0.000203 | 0.000118 | 1.7x 2**20 | 2048 | narrow in 1% bins | 0.001089 | 0.000396 | 0.000107 | 3.7x 2**26 | 80 | random.uniform | 0.064219 | 0.001170 | 0.000786 | 1.5x 2**26 | 80 | narrow in 1 bin | 0.056471 | 0.013283 | 0.011939 | 1.1x 2**26 | 500 | random.uniform | 0.078183 | 0.003411 | 0.000562 | 6.1x 2**26 | 500 | narrow in 1% bins | 0.056711 | 0.002763 | 0.002738 | 1.0x 2**26 | 2048 | random.uniform | 0.059296 | 0.003503 | 0.000533 | 6.6x 2**26 | 2048 | narrow in 1% bins | 0.061754 | 0.015703 | 0.000962 | 16.3x ## torch.histc (int64) #elem | nbins | distribution | CPU | PyTorch 2.0.0 | this PR | speedup -- | -- | -- | -- | -- | -- | -- 2**20 | 80 | random.uniform | N/A | 0.005614 | 9.47E-05 | 59.3x 2**20 | 80 | narrow in 1 bin | N/A | 0.003799 | 0.000395 | 9.6x 2**20 | 500 | random.uniform | N/A | 0.003665 | 9.58E-05 | 38.2x 2**20 | 500 | narrow in 1% bins | N/A | 0.001760 | 0.000178 | 9.9x 2**20 | 2048 | random.uniform | N/A | 0.000693 | 0.000111 | 6.2x 2**20 | 2048 | narrow in 1% bins | N/A | 1.082904 | 0.000123 | 8,802.4x 2**26 | 80 | random.uniform | N/A | 0.320400 | 0.001145 | 279.9x 2**26 | 80 | narrow in 1 bin | N/A | 0.193668 | 0.015229 | 12.7x 2**26 | 500 | random.uniform | N/A | 0.182897 | 0.000823 | 222.2x 2**26 | 500 | narrow in 1% bins | N/A | 0.089363 | 0.00376 | 23.8x 2**26 | 2048 | random.uniform | N/A | 0.033190 | 0.000832 | 39.9x 2**26 | 2048 | narrow in 1% bins | N/A | 71.721012 | 0.001525 | 47,017.8x ## Banchmark code Here is the benchmark code: ```python3 import time import torch cases = [ ("bincount bins=80 wide ", torch.randint(80, [2**20]), lambda x: torch.bincount(x, minlength=80)), ("bincount bins=80 narrow", torch.randint(1, [2**20]), lambda x: torch.bincount(x, minlength=80)), ("bincount bins=500 wide ", torch.randint(500, [2**20]), lambda x: torch.bincount(x, minlength=500)), ("bincount bins=500 narrow", torch.randint(5, [2**20]), lambda x: torch.bincount(x, minlength=500)), ("bincount bins=2048 wide ", torch.randint(2048, [2**20]), lambda x: torch.bincount(x, minlength=2048)), ("bincount bins=2048 narrow", torch.randint(20, [2**20]), lambda x: torch.bincount(x, minlength=2048)), ("histc_float bins=80 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=80, min=0., max=1.)), ("histc_float bins=80 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=80, min=0., max=1.)), ("histc_float bins=500 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=500, min=0., max=1.)), ("histc_float bins=500 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=500, min=0., max=1.)), ("histc_float bins=2048 wide ", torch.rand(2**20), lambda x: torch.histc(x, bins=2048, min=0., max=1.)), ("histc_float bins=2048 narrow", torch.rand(2**20)*.01, lambda x: torch.histc(x, bins=2048, min=0., max=1.)), ("histc_int bins=80 wide ", torch.randint(80, [2**20]), lambda x: torch.histc(x, bins=80, min=0., max=80.)), ("histc_int bins=80 narrow", torch.randint(1, [2**20]), lambda x: torch.histc(x, bins=80, min=0., max=80.)), ("histc_int bins=500 wide ", torch.randint(500, [2**20]), lambda x: torch.histc(x, bins=500, min=0., max=500.)), ("histc_int bins=500 narrow", torch.randint(5, [2**20]), lambda x: torch.histc(x, bins=500, min=0., max=500.)), ("histc_int bins=2048 wide ", torch.randint(2048, [2**20]), lambda x: torch.histc(x, bins=2048, min=0., max=2048.)), ("histc_int bins=2048 narrow", torch.randint(20, [2**20]), lambda x: torch.histc(x, bins=2048, min=0., max=2048.)), ] def test(case, device): name, x, func = case x = x.to(device) time_samples = [] for _ in range(15): torch.cuda.synchronize() t1 = time.time() func(x) torch.cuda.synchronize() t2 = time.time() time_samples.append(t2 - t1) median = sorted(time_samples)[len(time_samples) // 2] print(device, name, median) for case in cases: test(case, device="cuda") # for case in cases: # test(case, device="cpu") ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/97090 Approved by: https://github.com/ngimel
Author
Committer
Parents
Loading