pytorch
a2c4b18b - Migrate Embedding thrust sort to cub sort (#62495)

Commit
3 years ago
Migrate Embedding thrust sort to cub sort (#62495)

Summary:
This PR only migrates sort. Other thrust operations will be migrated in followup PRs.

Benchmark `num_embeddings` pulled from https://github.com/huggingface/transformers/tree/master/examples by
```
grep -P 'vocab_size.*(=|:)\s*[0-9]+' -r transformers/examples/
grep -P 'hidden_size.*(=|:)\s*[0-9]+' -r transformers/examples/
```
to get `vocab_size = 119547, 50265, 32000, 8000, 3052` (similar size omitted) and `hidden_size = 512, 768`

Code:
```python
import torch
import itertools

num_embeddings = (119547, 50265, 32000, 8000, 3052)
num_tokens = (4096, 16384)
hidden_sizes = (512, 768)

for ne, nt, nh in itertools.product(num_embeddings, num_tokens, hidden_sizes):
    print(f"Embedding size: {ne}, Tokens: {nt}, Hidden size: {nh}")
    embedding = torch.nn.Embedding(ne, nh).cuda()
    input_ = torch.randint(ne, (nt,), device='cuda')
    out = embedding(input_)
    torch.cuda.synchronize()
    %timeit out.backward(out, retain_graph=True); torch.cuda.synchronize()
```

## On CUDA 11.3.1

Before:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.43 ms ± 11.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.07 ms ± 56.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.61 ms ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.32 ms ± 8.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
738 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.02 ms ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
913 µs ± 3.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.27 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
559 µs ± 860 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
743 µs ± 630 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
713 µs ± 969 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
977 µs ± 884 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
301 µs ± 8.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
383 µs ± 4.36 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
409 µs ± 1.39 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
515 µs ± 766 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
215 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
250 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
271 µs ± 888 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
325 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

After:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.42 ms ± 1.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.05 ms ± 9.93 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.6 ms ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.3 ms ± 3.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
730 µs ± 811 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.01 ms ± 2.71 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
887 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.25 ms ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
556 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
744 µs ± 4.44 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
691 µs ± 570 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
957 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
309 µs ± 2.84 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
376 µs ± 2.18 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
381 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
487 µs ± 2.42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
202 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
239 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
243 µs ± 1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
340 µs ± 2.28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

## On CUDA 11.1

Before:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.41 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.05 ms ± 7.61 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.61 ms ± 1.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.32 ms ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
743 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.02 ms ± 2.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
912 µs ± 5.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.28 ms ± 6.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
555 µs ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
743 µs ± 655 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
714 µs ± 1.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
980 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
312 µs ± 396 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
386 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
413 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
512 µs ± 1.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
209 µs ± 585 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
271 µs ± 776 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
297 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
377 µs ± 3.87 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

After:
```
Embedding size: 119547, Tokens: 4096, Hidden size: 512
1.46 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 4096, Hidden size: 768
2.09 ms ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 512
1.64 ms ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 119547, Tokens: 16384, Hidden size: 768
2.35 ms ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 512
782 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 4096, Hidden size: 768
1.06 ms ± 596 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 512
945 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 50265, Tokens: 16384, Hidden size: 768
1.31 ms ± 553 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 512
603 µs ± 856 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 4096, Hidden size: 768
789 µs ± 500 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 512
752 µs ± 7.56 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 32000, Tokens: 16384, Hidden size: 768
1.01 ms ± 4.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 512
323 µs ± 7.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 4096, Hidden size: 768
398 µs ± 765 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 512
412 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 8000, Tokens: 16384, Hidden size: 768
519 µs ± 614 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 512
229 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 4096, Hidden size: 768
263 µs ± 417 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 512
274 µs ± 576 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Embedding size: 3052, Tokens: 16384, Hidden size: 768
354 µs ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62495

Reviewed By: gchanan

Differential Revision: D30176833

Pulled By: ngimel

fbshipit-source-id: 44148ebb53a0abfc1e5ab8b986865555bf326ad1
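
The Before/After timings above are easier to compare as relative changes. The following is a minimal sketch, not part of the PR: the helper names `mean_microseconds` and `percent_change` are hypothetical, and it assumes each timing line follows the IPython `%timeit` format shown in the benchmark output (`<mean> <unit> ± ...`).

```python
import re

# Matches the leading "<mean> <unit>" of a %timeit result line,
# e.g. "271 µs ± 888 ns per loop (...)". Hypothetical helper, not from the PR.
_TIMING = re.compile(r"^\s*([0-9.]+)\s*(ns|µs|us|ms|s)\b")

# Conversion factors from each unit to microseconds.
_TO_MICROSECONDS = {"ns": 1e-3, "µs": 1.0, "us": 1.0, "ms": 1e3, "s": 1e6}


def mean_microseconds(timeit_line: str) -> float:
    """Return the mean time of one %timeit result line, in microseconds."""
    match = _TIMING.match(timeit_line)
    if match is None:
        raise ValueError(f"not a %timeit result line: {timeit_line!r}")
    value, unit = match.groups()
    return float(value) * _TO_MICROSECONDS[unit]


def percent_change(before_line: str, after_line: str) -> float:
    """Relative change of the mean from before to after; negative means faster."""
    before = mean_microseconds(before_line)
    after = mean_microseconds(after_line)
    return (after - before) / before * 100.0


# CUDA 11.3.1, Embedding size 3052, Tokens 16384, Hidden size 512
print(round(percent_change("271 µs ± 888 ns per loop", "243 µs ± 1 µs per loop"), 1))
# CUDA 11.1, Embedding size 32000, Tokens 4096, Hidden size 512
print(round(percent_change("555 µs ± 2.61 µs per loop", "603 µs ± 856 ns per loop"), 1))
```

With the two pairs taken from the output above, this prints `-10.3` and `8.6`. The sketch compares means only and ignores the reported standard deviations, so small differences should be read with that in mind.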