Use int64 in pdist kernel to handle batches >= 46342 #30583 (#31593)
Summary:
Currently `torch.pdist` yields an illegal CUDA memory access for batch sizes >= 46342 as reported by SsnL in https://github.com/pytorch/pytorch/issues/30583.
Thanks for the minimal code reproduction, btw! ;)
Reason for this bug:
The calculation of `i` in [`pdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112) can overflow if a tensor with a batch size `>= 46342` is passed to `torch.pdist`.
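For context, the kernel maps each flat (condensed) index `k` back to a pair `(i, j)` via a quadratic inversion, and `i` is where the overflow bites. Here is a pure-Python sketch of that indexing scheme (my reconstruction based on the linked source, not the kernel itself; helper names are mine):

```python
import math

def condensed_index(i, j, n):
    # Flat index k of pair (i, j), i < j, in the condensed pdist output.
    return n * i - i * (i + 1) // 2 + j - i - 1

def recover_pair(k, n):
    # Invert k -> (i, j) with the quadratic formula used by the CUDA kernel
    # (reconstruction; the kernel does this arithmetic with int32 k).
    n2 = n - 0.5
    i = int(n2 - math.sqrt(n2 * n2 - 2 * k - 1))
    j = k - n * i + i * (i + 1) // 2 + i + 1
    return i, j

# Round-trip check for small n.
for n in range(2, 60):
    for i in range(n):
        for j in range(i + 1, n):
            assert recover_pair(condensed_index(i, j, n), n) == (i, j)
print("round-trip OK")
```

The `2 * k` inside the square root is the term that exceeds `int32` range for large batches.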
Detailed description:
* `result` is resized to `n * (n - 1) / 2 = 1073767311` elements ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/Distance.cpp#L140))
* `grid` is initialized as `result.numel()` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L246))
* `k` is assigned from `blockIdx.x` as an `int32` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L108))
* `i` is calculated using the term `2 * k`, which can reach `2147534622` ([line of code](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L112)); this overflows, since `2147534622 > 2147483647 (int32_max)`.
Using `const int64_t k = blockIdx.x;` fixes the illegal memory access. The same approach already seems to be used in [`cdist_kernel_cuda_impl`](https://github.com/pytorch/pytorch/blob/46ad80c8395379be5ba17624fd5dbad8e7a8e8d2/aten/src/ATen/native/cuda/DistanceKernel.cu#L198-L201).
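To make the overflow concrete, here is a small pure-Python sketch that simulates signed 32-bit wraparound for the largest `k` the kernel sees at this batch size (`to_int32` is a helper defined here, not part of the kernel):

```python
def to_int32(x):
    # Wrap a Python integer into signed 32-bit range, mimicking int32 overflow.
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

n = 46342
numel = n * (n - 1) // 2      # 1073767311 condensed distances
k = numel - 1                 # largest block index launched

print(2 * k)                  # 2147534620 as a 64-bit value: fine
print(to_int32(2 * k))        # negative: 2 * k has wrapped around int32
```

With `k` declared as `int64_t`, `2 * k` is computed in 64-bit arithmetic and stays positive.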
However, we might expect a slowdown, so I've timed the current PyTorch master vs. this PR:
(tested with inputs of shape `[x.size(0), 128]` on a V100; times in seconds)
|x.size(0) | int32 idx | int64 idx | slowdown |
|----------|-----------|-----------|----------|
| 50000 | - | 4.4460 | - |
| 25000 | 1.02522 | 1.10869 | 7.53% |
| 12500 | 0.25182 | 0.27277 | 7.68% |
| 6250 | 0.06291 | 0.06817 | 7.72% |
| 3125 | 0.01573 | 0.01704 | 7.69% |
| 1562 | 0.00393 | 0.00426 | 7.75% |
While checking the backward kernel, it seems I'm triggering another error, which appears to be tied to a size limit:
```python
x = torch.randn(1449, 1, device='cuda', requires_grad=True)
out = torch.pdist(x)
out.mean().backward()
> RuntimeError: CUDA error: invalid configuration argument
```
Inputs of shape `[<=1448, 1]` work fine.
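One observation on that boundary (just a guess at the root cause, not confirmed): the condensed output sizes for 1448 and 1449 points straddle `2**20` exactly, which might hint at a launch-configuration limit in the backward kernel:

```python
def condensed_size(n):
    # Number of pairwise distances torch.pdist produces for n points.
    return n * (n - 1) // 2

print(condensed_size(1448))  # 1047628, just below 2**20
print(condensed_size(1449))  # 1049076, just above 2**20
print(2 ** 20)               # 1048576
```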
I'll take another look at this issue. Let me know if the potential fix should go into this PR or if I should open a new issue.
CC ngimel, csarofeen
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31593
Differential Revision: D19825571
Pulled By: ngimel
fbshipit-source-id: ace9ccab49f3cf0ce894cdb6daef0795e2e8ec03