[pytorch] Accelerate indexing_backward_kernel with duplicates (#99441 attempt 2) (#100505)
By knowing the stride value ahead of time, we can simplify the kernel code as follows:
If stride == 1 we can use the whole warp to reduce the gradients
If stride < warp_size we don't need the internal while (start_feature < stride) loop as blockDim.x is always 32
This changes improve the performance of the kernel when duplicates are present and do not affect the performance with low amount of duplicates. The implementation is deterministic.
The proposed implementation uses opmath_t to accumulate in registers the gradient values so when using FP16/BF16 it may overflow if the number of elements is large. This is different from the initial implementation who accumulates in scalar_t and does not overflow. In addition, when the stride is 1, we are using warp shuffles to sum the gradient so the order of the addition is slightly different than a reference implementation which causes some minor numerical differences when compared to a reference.
TEST CODE:
```
# The first element is the number of iterations.
# The second represents the number of unique elements. If
# set to 0, the number of unique elements is equal to the
# number of elements.
# The remaining elements are the tensor dimensions.
basic_indexing_tests = [
[10, 0, 12345],
[10, 4, 12345],
[10, 16, 512, 512, 32],
[10, 0, 4, 4],
[10, 0, 32, 32],
[10, 8, 32, 32],
[10, 8, 64, 32, 16],
[10, 0, 64, 32, 16],
[10, 16, 512, 512, 32],
[10, 0, 675, 999, 13],
[10, 0, 123, 456, 31],
[10, 0, 512, 512, 32],
[10, 4, 512, 512, 32],
[10, 2, 512, 512, 32],
[10, 0, 128, 128, 16, 16],
[10, 8, 128, 126, 16, 16],
[10, 4, 128, 126, 16, 16],
[10, 0, 64, 64, 16, 16, 16],
[10, 8, 64, 64, 16, 16, 16],
[10, 2, 64, 64, 16, 16, 16],
[10, 1, 64, 64, 16, 16, 16],
]
def run_basic_indexing_on_device(x, index, expected, device_string, iters):
x_dev = x.to(device_string)
x_dev = x_dev.detach().requires_grad_()
index_dev = index.to(device_string)
# Run backward pass; keep gradients and measure time
torch.cuda.synchronize()
t_bw_s = time()
for _ in range(iters):
y = x_dev[index_dev]
z = y.sum()
z.backward()
torch.cuda.synchronize()
t_bw_s = (time() - t_bw_s) / iters
return (x_dev.grad, t_bw_s)
def run_basic_indexing_test(test_input):
tensor_size = tuple(test_input[:5])
niters = test_input[0]
num_unique = test_input[1]
tensor_size = tuple(test_input[2:])
numel = 1
for dim in tensor_size:
numel *= dim
if num_unique == 0:
num_unique = numel
index = torch.randint(0, num_unique, tensor_size, dtype=torch.long, device="cpu")
x = torch.randn((numel,), dtype=torch.float32, device="cuda")
index = index.detach()
x = x.detach().requires_grad_()
(cpu_grad, t_bw_cpu) = run_basic_indexing_on_device(x, index, numel / 2, "cpu", 1)
(gpu_grad, t_bw_gpu) = run_basic_indexing_on_device(x, index, numel / 2, "cuda", 1)
max_delta = torch.max(torch.abs(cpu_grad - gpu_grad.to("cpu")))
missmatches = torch.nonzero(torch.abs(cpu_grad - gpu_grad.to("cpu")))
(gpu_grad_perf, t_gpu) = run_basic_indexing_on_device(
x, index, numel / 2, "cuda", niters
)
print(
"test = {}, delta = {:.5f}, missmatches = {} duration_ms = {:.3f}".format(
tuple(test_input), max_delta, missmatches, t_gpu * 1000.0
)
)
if torch.numel(missmatches) > 0:
print("cpu grad = {}", cpu_grad[missmatches])
print("gpu grad = {}", gpu_grad[missmatches])
```
RESULTS:
```
Default Implementation
test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.726
test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.867
test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.514
test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.689
test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.547
test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.537
test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.199
test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.584
test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 80.055
test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.411
test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.419
test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.048
test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.633
test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 606.403
test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4.099
test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 76.813
test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.760
test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.547
test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 317.583
test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1204.800
test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2412.133
Small Stride Kernel Version
test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.904
test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.156
test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 308.878
test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.566
test = (1, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.540
test = (1, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.550
test = (1, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2.868
test = (1, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.656
test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 307.856
test = (1, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.624
test = (1, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.837
test = (1, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 6.274
test = (1, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1127.040
test = (1, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 2123.942
test = (1, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.282
test = (1, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 288.997
test = (1, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 547.267
test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 12.844
test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1178.934
test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 4262.042
test = (1, 1, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8172.318
Stride 1 Kernel Version
test = (1, 0, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.692
test = (1, 4, 12345), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.834
test = (1, 16, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 81.023
test = (1, 0, 4, 4), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.631
test = (100, 0, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.491
test = (100, 8, 32, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.477
test = (50, 8, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.561
test = (50, 0, 64, 32, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 0.516
test = (16, 10, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 126.455
test = (10, 0, 675, 999, 13), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 8.238
test = (10, 0, 123, 456, 31), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1.520
test = (10, 0, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 7.854
test = (10, 4, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 306.327
test = (10, 2, 512, 512, 32), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 610.498
test = (5, 0, 128, 128, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 3.684
test = (5, 8, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 75.604
test = (5, 4, 128, 126, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 148.679
test = (1, 0, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 16.525
test = (1, 8, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 315.095
test = (1, 2, 64, 64, 16, 16, 16), delta = 0.00000, missmatches = tensor([], size=(0, 1), dtype=torch.int64) duration_ms = 1214.715
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100505
Approved by: https://github.com/ngimel