pytorch
ecc6358d - Port nonzero cuda from THC to ATen (#44259)

Commit
4 years ago
Port nonzero cuda from THC to ATen (#44259) Summary: 1) Ports nonzero from THC to ATen 2) replaces most thrust uses with cub, to avoid synchronization and to improve performance. There is still one necessary synchronization point, communicating number of nonzero elements from GPU to CPU 3) slightly changes algorithm, now we first compute the number of nonzeros, and then allocate correct-sized output, instead of allocating full-sized output as was done before, to account for possibly all elements being non-zero 4) unfortunately, since the last transforms are still done with thrust, 2) is slightly beside the point, however it is a step towards a future without thrust 4) hard limits the number of elements in the input tensor to MAX_INT. Previous implementation allocated a Long tensor with the size ndim*nelements, so that would be at least 16 GB for a tensor with MAX_INT elements. It is reasonable to say that larger tensors could not be used anyway. Benchmarking is done for tensors with approximately half non-zeros <details><summary>Benchmarking script</summary> <p> ``` import torch from torch.utils._benchmark import Timer from torch.utils._benchmark import Compare import sys device = "cuda" results = [] for numel in (1024 * 128,):#, 1024 * 1024, 1024 * 1024 * 128): inp = torch.randint(2, (numel,), device="cuda", dtype=torch.float) for ndim in range(2,3):#(1,4): if ndim == 1: shape = (numel,) elif ndim == 2: shape = (1024, numel // 1024) else: shape = (1024, 128, numel // 1024 // 128) inp = inp.reshape(shape) repeats = 3 timer = Timer(stmt="torch.nonzero(inp, as_tuple=False)", label="Nonzero", sub_label=f"number of elts {numel}", description = f"ndim {ndim}", globals=globals()) for i in range(repeats): results.append(timer.blocked_autorange()) print(f"\rnumel {numel} ndim {ndim}", end="") sys.stdout.flush() comparison = Compare(results) comparison.print() ``` </p> </details> ### Results Before: ``` [--------------------------- Nonzero ---------------------------] | ndim 1 | ndim 2 | ndim 3 1 threads: ------------------------------------------------------ number of elts 131072 | 55.2 | 71.7 | 90.5 number of elts 1048576 | 113.2 | 250.7 | 497.0 number of elts 134217728 | 8353.7 | 23809.2 | 54602.3 Times are in microseconds (us). ``` After: ``` [-------------------------- Nonzero --------------------------] | ndim 1 | ndim 2 | ndim 3 1 threads: ---------------------------------------------------- number of elts 131072 | 48.6 | 79.1 | 90.2 number of elts 1048576 | 64.7 | 134.2 | 161.1 number of elts 134217728 | 3748.8 | 7881.3 | 9953.7 Times are in microseconds (us). ``` There's a real regression for smallish 2D tensor due to added work of computing number of nonzero elements, however, for other sizes there are significant gains, and there are drastically lower memory requirements. Perf gains would be even larger for tensors with fewer nonzeros. Pull Request resolved: https://github.com/pytorch/pytorch/pull/44259 Reviewed By: izdeby Differential Revision: D23581955 Pulled By: ngimel fbshipit-source-id: 0b99a767fd60d674003d83f0848dc550d7a363dc
Author
Natalia Gimelshein
Parents
Loading