Port nonzero cuda from THC to ATen (#44259)
Summary:
1) Ports nonzero from THC to ATen.
2) Replaces most thrust uses with cub, to avoid synchronizations and improve performance. There is still one necessary synchronization point: communicating the number of nonzero elements from GPU to CPU.
3) Slightly changes the algorithm: the number of nonzeros is computed first, and a correctly sized output is then allocated, instead of allocating a full-sized output (large enough for the case where all elements are nonzero) as was done before. A minimal sketch of this two-pass approach is shown after this list.
4) Unfortunately, since the last transforms are still done with thrust, 2) is slightly beside the point; it is nevertheless a step towards a future without thrust.
5) Hard-limits the number of elements in the input tensor to MAX_INT. The previous implementation allocated a Long tensor of size ndim*nelements, which would be at least 16 GB for a tensor with MAX_INT elements, so it is reasonable to say that larger tensors could not be used anyway.
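Below is a minimal, self-contained CUDA sketch of the two-pass idea (count the nonzeros, synchronize once to fetch the count, allocate an exactly sized output, then compact the indices) using cub's device-wide primitives. It is an illustration only, not the ATen kernel: it handles a 1-D float input, produces flat indices, omits error checking, and leaves out the final conversion of flat indices into per-dimension indices (which the real kernel still does with thrust, per 4)); the functor name `IsNonZero` and the example values are made up for this sketch.
```
// Illustrative sketch only -- not the ATen implementation.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Maps an input value to 1 if it is nonzero, 0 otherwise.
struct IsNonZero {
  __host__ __device__ int operator()(float x) const { return x != 0.f ? 1 : 0; }
};

int main() {
  const int n = 8;
  const float h_in[n] = {0.f, 1.f, 0.f, 2.f, 3.f, 0.f, 0.f, 4.f};

  float* d_in;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

  // Iterator that yields 1 for nonzero elements and 0 otherwise.
  cub::TransformInputIterator<int, IsNonZero, const float*> flags(d_in, IsNonZero{});

  // Pass 1: count the nonzeros (instead of allocating a full-sized output up front).
  int* d_count;
  cudaMalloc(&d_count, sizeof(int));
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, flags, d_count, n);  // query temp storage size
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, flags, d_count, n);  // actual reduction
  cudaFree(d_temp);

  // The single unavoidable GPU->CPU synchronization: the host needs the count
  // to allocate a correctly sized output.
  int count = 0;
  cudaMemcpy(&count, d_count, sizeof(int), cudaMemcpyDeviceToHost);

  // Pass 2: compact the flat indices 0..n-1 of the flagged (nonzero) elements.
  int* d_idx;
  cudaMalloc(&d_idx, count * sizeof(int));
  cub::CountingInputIterator<int> counting(0);
  d_temp = nullptr;
  temp_bytes = 0;
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, counting, flags, d_idx, d_count, n);  // size query
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, counting, flags, d_idx, d_count, n);

  std::vector<int> h_idx(count);
  cudaMemcpy(h_idx.data(), d_idx, count * sizeof(int), cudaMemcpyDeviceToHost);
  for (int i : h_idx) printf("%d ", i);  // expected: 1 3 4 7
  printf("\n");
  // (cleanup of device allocations omitted for brevity)
  return 0;
}
```
In the real kernel, the compacted flat indices would then still be turned into a (count, ndim) index tensor; that last step is where thrust remains in use.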
Benchmarking is done on tensors where approximately half of the elements are nonzero.
<details><summary>Benchmarking script</summary>
<p>
```
import torch
from torch.utils._benchmark import Timer
from torch.utils._benchmark import Compare
import sys

device = "cuda"
results = []
for numel in (1024 * 128,):  # , 1024 * 1024, 1024 * 1024 * 128):
    inp = torch.randint(2, (numel,), device="cuda", dtype=torch.float)
    for ndim in range(2, 3):  # (1, 4):
        if ndim == 1:
            shape = (numel,)
        elif ndim == 2:
            shape = (1024, numel // 1024)
        else:
            shape = (1024, 128, numel // 1024 // 128)
        inp = inp.reshape(shape)
        repeats = 3
        timer = Timer(stmt="torch.nonzero(inp, as_tuple=False)", label="Nonzero",
                      sub_label=f"number of elts {numel}",
                      description=f"ndim {ndim}", globals=globals())
        for i in range(repeats):
            results.append(timer.blocked_autorange())
            print(f"\rnumel {numel} ndim {ndim}", end="")
            sys.stdout.flush()
comparison = Compare(results)
comparison.print()
```
</p>
</details>
### Results
Before:
```
[--------------------------- Nonzero ---------------------------]
| ndim 1 | ndim 2 | ndim 3
1 threads: ------------------------------------------------------
number of elts 131072 | 55.2 | 71.7 | 90.5
number of elts 1048576 | 113.2 | 250.7 | 497.0
number of elts 134217728 | 8353.7 | 23809.2 | 54602.3
Times are in microseconds (us).
```
After:
```
[-------------------------- Nonzero --------------------------]
| ndim 1 | ndim 2 | ndim 3
1 threads: ----------------------------------------------------
number of elts 131072 | 48.6 | 79.1 | 90.2
number of elts 1048576 | 64.7 | 134.2 | 161.1
number of elts 134217728 | 3748.8 | 7881.3 | 9953.7
Times are in microseconds (us).
```
There is a real regression for small-ish 2D tensors due to the added work of computing the number of nonzero elements; however, for other sizes there are significant gains, and the memory requirements are drastically lower. Perf gains would be even larger for tensors with fewer nonzeros.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44259
Reviewed By: izdeby
Differential Revision: D23581955
Pulled By: ngimel
fbshipit-source-id: 0b99a767fd60d674003d83f0848dc550d7a363dc
Author: Natalia Gimelshein