pytorch
296e1ba4 - Row and column select support for block compressed sparse tensors (#88733)

Commit
3 years ago
Row and column select support for block compressed sparse tensors (#88733) As in the title: - Support `select` and `select_copy` on block sparse compressed tensors - Fixes incorrect results when selecting dense dimensions The PR also improves the performance of indexing sparse compressed tensors considerably: <details> Before: ```python In [3]: a=torch.rand((1000, 1000)).to_sparse_csr() In [4]: %timeit a.select(0, 0) 606 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [5]: %timeit a.select(1, 0) 527 µs ± 57.7 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [6]: %timeit a[0, 0] 617 µs ± 3.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [7]: a = a.cuda() In [8]: %timeit a.select(0, 0); torch.cuda.synchronize(); 1.19 ms ± 137 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [9]: %timeit a.select(1, 0); torch.cuda.synchronize(); 1.2 ms ± 119 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [10]: %timeit a[0, 0]; torch.cuda.synchronize(); 1.23 ms ± 482 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) ``` This PR: ```python In [3]: a=torch.rand((1000, 1000)).to_sparse_csr() In [4]: %timeit a.select(0, 0) 4.75 µs ± 8.94 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) In [5]: %timeit a.select(1, 0) 565 µs ± 156 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [6]: %timeit a[0, 0] 13.1 µs ± 435 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each) In [7]: a = a.cuda() In [8]: %timeit a.select(0, 0); torch.cuda.synchronize(); 21.6 µs ± 23.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) In [9]: %timeit a.select(1, 0); torch.cuda.synchronize(); 1.15 ms ± 3.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) In [10]: %timeit a[0, 0]; torch.cuda.synchronize(); 63.7 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each) ``` </details> Pull Request resolved: https://github.com/pytorch/pytorch/pull/88733 Approved by: https://github.com/nikitaved, https://github.com/amjames, https://github.com/cpuhrsch
Author
Committer
Parents
Loading