Migrate Embedding thrust sort to cub sort (#62495)
Summary:
This PR only migrates sort. Other thrust operations will be migrated in followup PRs
Benchmark `num_embeddings` pulled from https://github.com/huggingface/transformers/tree/master/examples by
```
grep -P 'vocab_size.*(=|:)\s*[0-9]+' -r transformers/examples/
grep -P 'hidden_size.*(=|:)\s*[0-9]+' -r transformers/examples/
```
to get `vocab_size = 119547, 50265, 32000, 8000, 3052` (similar size omitted) and `hidden_size = 512, 768`
Code:
```python
import torch
import itertools
num_embeddings = (119547, 50265, 32000, 8000, 3052)
num_tokens = (4096, 16384)
hidden_sizes = (512, 768)
for ne, nt, nh in itertools.product(num_embeddings, num_tokens, hidden_sizes):
print(f"Embedding size: {ne}, Tokens: {nt}, Hidden size: {nh}")
embedding = torch.nn.Embedding(ne, nh).cuda()
input_ = torch.randint(ne, (nt,), device='cuda')
out = embedding(input_)
torch.cuda.synchronize()
%timeit out.backward(out, retain_graph=True); torch.cuda.synchronize()
```
## On CUDA 11.3.1
Before:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.43 ms ± 11.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.07 ms ± 56.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.61 ms ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.32 ms ± 8.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
738 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.02 ms ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
913 µs ± 3.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.27 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
559 µs ± 860 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
743 µs ± 630 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
713 µs ± 969 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
977 µs ± 884 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
301 µs ± 8.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
383 µs ± 4.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
409 µs ± 1.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
515 µs ± 766 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
215 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
250 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
271 µs ± 888 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
325 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
After:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.42 ms ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.05 ms ± 9.93 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.6 ms ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.3 ms ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
730 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.01 ms ± 2.71 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
887 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.25 ms ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
556 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
744 µs ± 4.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
691 µs ± 570 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
957 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
309 µs ± 2.84 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
376 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
381 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
487 µs ± 2.42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
202 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
239 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
243 µs ± 1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
340 µs ± 2.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
## On CUDA 11.1
Before:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.41 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.05 ms ± 7.61 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.61 ms ± 1.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.32 ms ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
743 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.02 ms ± 2.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
912 µs ± 5.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.28 ms ± 6.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
555 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
743 µs ± 655 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
714 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
980 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
312 µs ± 396 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
386 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
413 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
512 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
209 µs ± 585 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
271 µs ± 776 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
297 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
377 µs ± 3.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
After:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.46 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.09 ms ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.64 ms ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.35 ms ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
782 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.06 ms ± 596 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
945 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.31 ms ± 553 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
603 µs ± 856 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
789 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
752 µs ± 7.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
1.01 ms ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
323 µs ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
398 µs ± 765 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
412 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
519 µs ± 614 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
229 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
263 µs ± 417 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
274 µs ± 576 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
354 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62495
Reviewed By: gchanan
Differential Revision: D30176833
Pulled By: ngimel
fbshipit-source-id: 44148ebb53a0abfc1e5ab8b986865555bf326ad1