pytorch
9bc68fcd - [pytorch] Accelerate indexing_backward_kernel with duplicates (#99441 attempt 2) (#100505)

Commit
1 year ago
[pytorch] Accelerate indexing_backward_kernel with duplicates (#99441 attempt 2) (#100505) By knowing the stride value ahead of time, we can simplify the kernel code as follows: If stride == 1 we can use the whole warp to reduce the gradients If stride < warp_size we don't need the internal while (start_feature < stride) loop as blockDim.x is always 32 This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic. The proposed implementation uses opmath_t to accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates in scalar_t and does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference. TEST CODE: ``` # The first element is the number of iterations. # The second represents the number of unique elements. If # set to 0, the number of unique elements is equal to the # number of elements. # The remaining elements are the tensor dimensions. basic_indexing_tests = [ [10, 0, 12345], [10, 4, 12345], [10, 16, 512, 512, 32], [10, 0, 4, 4], [10, 0, 32, 32], [10, 8, 32, 32], [10, 8, 64, 32, 16], [10, 0, 64, 32, 16], [10, 16, 512, 512, 32], [10, 0, 675, 999, 13], [10, 0, 123, 456, 31], [10, 0, 512, 512, 32], [10, 4, 512, 512, 32], [10, 2, 512, 512, 32], [10, 0, 128, 128, 16, 16], [10, 8, 128, 126, 16, 16], [10, 4, 128, 126, 16, 16], [10, 0, 64, 64, 16, 16, 16], [10, 8, 64, 64, 16, 16, 16], [10, 2, 64, 64, 16, 16, 16], [10, 1, 64, 64, 16, 16, 16], ] def run_basic_indexing_on_device(x, index, expected, device_string, iters): x_dev = x.to(device_string) x_dev = x_dev.detach().requires_grad_() index_dev = index.to(device_string) # Run backward pass; keep gradients and measure time torch.cuda.synchronize() t_bw_s = time() for _ in range(iters): y = x_dev[index_dev] z = y.sum() z.backward() torch.cuda.synchronize() t_bw_s = (time() - t_bw_s) / iters return (x_dev.grad, t_bw_s) def run_basic_indexing_test(test_input): tensor_size = tuple(test_input[:5]) niters = test_input[0] num_unique = test_input[1] tensor_size = tuple(test_input[2:]) numel = 1 for dim in tensor_size: numel *= dim if num_unique == 0: num_unique = numel index = torch.randint(0, num_unique, tensor_size, dtype=torch.long, device="cpu") x = torch.randn((numel,), dtype=torch.float32, device="cuda") index = index.detach() x = x.detach().requires_grad_() (cpu_grad, t_bw_cpu) = run_basic_indexing_on_device(x, index, numel / 2, "cpu", 1) (gpu_grad, t_bw_gpu) = run_basic_indexing_on_device(x, index, numel / 2, "cuda", 1) max_delta = torch.max(torch.abs(cpu_grad - gpu_grad.to("cpu"))) missmatches = torch.nonzero(torch.abs(cpu_grad - gpu_grad.to("cpu"))) (gpu_grad_perf, t_gpu) = run_basic_indexing_on_device( x, index, numel / 2, "cuda", niters ) print( "test = {}, delta = {:.5f}, missmatches = {} duration_ms = {:.3f}".format( tuple(test_input), max_delta, missmatches, t_gpu * 1000.0 ) ) if torch.numel(missmatches) > 0: print("cpu grad = {}", cpu_grad[missmatches]) print("gpu grad = {}", gpu_grad[missmatches]) ``` RESULTS: ``` Default Implementation test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.726 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.867 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.514 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.689 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.547 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.537 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.199 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.584 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.055 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.411 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.419 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.048 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.633 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 606.403 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4.099 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 76.813 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.760 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.547 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 317.583 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1204.800 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2412.133 Small Stride Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.904 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.156 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 308.878 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.566 test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.540 test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.550 test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.868 test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.656 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.856 test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.624 test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.837 test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.274 test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1127.040 test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2123.942 test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.282 test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 288.997 test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 547.267 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 12.844 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1178.934 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4262.042 test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8172.318 Stride 1 Kernel Version test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.692 test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.834 test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 81.023 test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.631 test = (100, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.491 test = (100, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.477 test = (50, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.561 test = (50, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.516 test = (16, 10, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 126.455 test = (10, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.238 test = (10, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.520 test = (10, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 7.854 test = (10, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 306.327 test = (10, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 610.498 test = (5, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.684 test = (5, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 75.604 test = (5, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.679 test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.525 test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 315.095 test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1214.715 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/100505 Approved by: https://github.com/ngimel
Committer
Parents
Loading