`index_select` for COO CUDA tensors. (#77551)
This PR brings a native CUDA implementation of `index_select` for sparse COO tensors. On master, CUDA support is provided by silently converting the tensors to CPU and running the CPU kernel.
The `nnz >> size` case could be optimized further, similar to the approach taken in https://github.com/pytorch/pytorch/pull/72710.
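For reference, the semantics being implemented can be sketched in pure Python (this is an illustrative model of `index_select` on a 2-D COO tensor, not the PR's kernel; the function name and layout are hypothetical). For each output position `j`, every stored entry whose coordinate along `dim` equals `index[j]` is emitted with that coordinate remapped to `j`:

```python
def coo_index_select(indices, values, dim, index):
    """Illustrative index_select for a 2-D COO tensor.

    `indices` is a 2 x nnz list of coordinates, `values` the nnz stored
    entries. Returns new (indices, values) with size len(index) along `dim`.
    """
    out_indices, out_values = [[], []], []
    for j, sel in enumerate(index):
        for k in range(len(values)):
            if indices[dim][k] == sel:
                coord = [indices[0][k], indices[1][k]]
                coord[dim] = j  # remap the selected coordinate to position j
                out_indices[0].append(coord[0])
                out_indices[1].append(coord[1])
                out_values.append(values[k])
    return out_indices, out_values
```

This naive double loop is O(nnz * index_len); the point of a dedicated kernel is to parallelize this matching step on the GPU rather than falling back to CPU.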
Some benchmarks:
<details>
<summary>PR/torch_sparse/master</summary>
```
[------------------------------- cuda coo.index_select -------------------------------]
| PR | torch_sparse | master
32 threads: ---------------------------------------------------------------------------
n=10000, nnz=100, index_len=100, dim=0 | 96 | 327 | 70
n=10000, nnz=100, index_len=100, dim=1 | 120 | 505 | 74
n=10000, nnz=100, index_len=1000, dim=0 | 90 | 333 | 93
n=10000, nnz=100, index_len=1000, dim=1 | 120 | 499 | 98
n=10000, nnz=100, index_len=10000, dim=0 | 92 | 331 | 350
n=10000, nnz=100, index_len=10000, dim=1 | 100 | 506 | 352
n=100000, nnz=1000, index_len=100, dim=0 | 53 | 274 | 60
n=100000, nnz=1000, index_len=100, dim=1 | 90 | 368 | 71
n=100000, nnz=1000, index_len=1000, dim=0 | 93 | 332 | 100
n=100000, nnz=1000, index_len=1000, dim=1 | 130 | 501 | 140
n=100000, nnz=1000, index_len=10000, dim=0 | 100 | 341 | 522
n=100000, nnz=1000, index_len=10000, dim=1 | 130 | 530 | 549
n=1000000, nnz=10000, index_len=100, dim=0 | 90 | 429 | 110
n=1000000, nnz=10000, index_len=100, dim=1 | 296 | 810 | 355
n=1000000, nnz=10000, index_len=1000, dim=0 | 100 | 435 | 170
n=1000000, nnz=10000, index_len=1000, dim=1 | 309 | 830 | 548
n=1000000, nnz=10000, index_len=10000, dim=0 | 110 | 446 | 750
n=1000000, nnz=10000, index_len=10000, dim=1 | 310 | 830 | 1000
n=10, nnz=100, index_len=100, dim=0 | 90 | 333 | 74
n=10, nnz=100, index_len=100, dim=1 | 100 | 497 | 78
n=10, nnz=100, index_len=1000, dim=0 | 90 | 329 | 140
n=10, nnz=100, index_len=1000, dim=1 | 100 | 800 | 100
n=10, nnz=100, index_len=10000, dim=0 | 93 | 340 | 900
n=10, nnz=100, index_len=10000, dim=1 | 120 | 800 | 489
n=10, nnz=1000, index_len=100, dim=0 | 90 | 321 | 140
n=10, nnz=1000, index_len=100, dim=1 | 100 | 680 | 140
n=10, nnz=1000, index_len=1000, dim=0 | 110 | 349 | 670
n=10, nnz=1000, index_len=1000, dim=1 | 130 | 740 | 800
n=10, nnz=1000, index_len=10000, dim=0 | 302 | 503 | 4882
n=10, nnz=1000, index_len=10000, dim=1 | 325 | 2257 | 5262
n=10, nnz=10000, index_len=100, dim=0 | 229 | 349 | 810
n=10, nnz=10000, index_len=100, dim=1 | 433 | 870 | 700
n=10, nnz=10000, index_len=1000, dim=0 | 666 | 502 | 5581
n=10, nnz=10000, index_len=1000, dim=1 | 826 | 2379 | 4820
n=10, nnz=10000, index_len=10000, dim=0 | 2534 | 2700 | 80000
n=10, nnz=10000, index_len=10000, dim=1 | 2723 | 18540 | 80000
n=100, nnz=1000, index_len=100, dim=0 | 94 | 324 | 110
n=100, nnz=1000, index_len=100, dim=1 | 100 | 499 | 110
n=100, nnz=1000, index_len=1000, dim=0 | 96 | 337 | 150
n=100, nnz=1000, index_len=1000, dim=1 | 130 | 800 | 140
n=100, nnz=1000, index_len=10000, dim=0 | 100 | 346 | 900
n=100, nnz=1000, index_len=10000, dim=1 | 130 | 760 | 900
n=100, nnz=10000, index_len=100, dim=0 | 90 | 323 | 190
n=100, nnz=10000, index_len=100, dim=1 | 279 | 800 | 180
n=100, nnz=10000, index_len=1000, dim=0 | 110 | 339 | 781
n=100, nnz=10000, index_len=1000, dim=1 | 294 | 870 | 800
n=100, nnz=10000, index_len=10000, dim=0 | 315 | 505 | 6264
n=100, nnz=10000, index_len=10000, dim=1 | 497 | 2398 | 5404
n=1000, nnz=10000, index_len=100, dim=0 | 90 | 333 | 160
n=1000, nnz=10000, index_len=100, dim=1 | 279 | 635 | 150
n=1000, nnz=10000, index_len=1000, dim=0 | 100 | 328 | 215
n=1000, nnz=10000, index_len=1000, dim=1 | 287 | 810 | 207
n=1000, nnz=10000, index_len=10000, dim=0 | 100 | 339 | 900
n=1000, nnz=10000, index_len=10000, dim=1 | 291 | 880 | 1000
n=1000, nnz=100000, index_len=100, dim=0 | 92 | 358 | 435
n=1000, nnz=100000, index_len=100, dim=1 | 302 | 900 | 530
n=1000, nnz=100000, index_len=1000, dim=0 | 130 | 360 | 1000
n=1000, nnz=100000, index_len=1000, dim=1 | 329 | 930 | 1200
n=1000, nnz=100000, index_len=10000, dim=0 | 343 | 530 | 7000
n=1000, nnz=100000, index_len=10000, dim=1 | 545 | 2446 | 6100
n=1000, nnz=1000000, index_len=100, dim=0 | 355 | 394 | 2210
n=1000, nnz=1000000, index_len=100, dim=1 | 1660 | 2276 | 2674
n=1000, nnz=1000000, index_len=1000, dim=0 | 877 | 574 | 6700
n=1000, nnz=1000000, index_len=1000, dim=1 | 2449 | 3782 | 9000
n=1000, nnz=1000000, index_len=10000, dim=0 | 3112 | 2931 | 57000
n=1000, nnz=1000000, index_len=10000, dim=1 | 7340 | 20220 | 65700
Times are in microseconds (us).
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77551
Approved by: https://github.com/cpuhrsch