pytorch
8f953ef5 - Increase token count threshold for calling thrust sort in embedding backward (#49913)

Summary: Increases the token count threshold below which embedding backward uses the custom CUDA kernel instead of the thrust sort path, widening the range of inputs covered by the custom kernel. The tables below show the embedding backward times and speedups on DGXV100-128GB and DGXA100-640GB. I picked 6144 as the new threshold since, below it, the custom CUDA kernel is faster in almost all cases. One important advantage of the custom CUDA kernel is that it allows CUDA graph capture, whereas the thrust code path performs CPU syncs that prohibit graph capture (the times below were collected without graph capture; see the capture sketch after the tables). For reference, the MLPerf BERT benchmark uses num_features=1024.

| num_tokens | num_features | thrust path (ms) | custom kernel (ms) | speedup |
| -- | -- | -- | -- | -- |
| DGXV100 | | | | |
| 1024 | 64 | 0.36 | 0.18 | 2.04 |
| 1024 | 256 | 0.43 | 0.30 | 1.46 |
| 1024 | 1024 | 0.89 | 0.74 | 1.20 |
| 1024 | 2048 | 1.50 | 1.33 | 1.12 |
| 1024 | 4096 | 2.71 | 2.50 | 1.08 |
| 1024 | 8192 | 5.07 | 4.89 | 1.04 |
| 2048 | 64 | 0.33 | 0.23 | 1.46 |
| 2048 | 256 | 0.41 | 0.33 | 1.26 |
| 2048 | 1024 | 0.92 | 0.79 | 1.17 |
| 2048 | 2048 | 1.54 | 1.38 | 1.11 |
| 2048 | 4096 | 2.80 | 2.54 | 1.10 |
| 2048 | 8192 | 5.29 | 4.98 | 1.06 |
| 4096 | 64 | 0.46 | 0.32 | 1.43 |
| 4096 | 256 | 0.50 | 0.47 | 1.07 |
| 4096 | 1024 | 1.02 | 0.88 | 1.15 |
| 4096 | 2048 | 1.70 | 1.59 | 1.07 |
| 4096 | 4096 | 3.06 | 2.68 | 1.14 |
| 4096 | 8192 | 5.79 | 5.28 | 1.10 |
| 5120 | 64 | 0.42 | 0.33 | 1.28 |
| 5120 | 256 | 0.51 | 0.46 | 1.11 |
| 5120 | 1024 | 1.06 | 0.93 | 1.14 |
| 5120 | 2048 | 1.77 | 1.55 | 1.14 |
| 5120 | 4096 | 3.18 | 2.76 | 1.15 |
| 5120 | 8192 | 6.24 | 5.46 | 1.14 |
| 6144 | 64 | 0.42 | 0.36 | 1.17 |
| 6144 | 256 | 0.52 | 0.50 | 1.05 |
| 6144 | 1024 | 1.10 | 0.98 | 1.13 |
| 6144 | 2048 | 1.85 | 1.61 | 1.15 |
| 6144 | 4096 | 3.34 | 2.84 | 1.17 |
| 6144 | 8192 | 6.19 | 5.69 | 1.09 |
| 8192 | 64 | 0.42 | 0.48 | 0.88 |
| 8192 | 256 | 0.51 | 0.65 | 0.78 |
| 8192 | 1024 | 1.14 | 1.12 | 1.01 |
| 8192 | 2048 | 1.92 | 1.77 | 1.09 |
| 8192 | 4096 | 3.49 | 3.03 | 1.15 |
| 8192 | 8192 | 6.59 | 5.96 | 1.11 |
| 16384 | 64 | 0.46 | 0.82 | 0.56 |
| 16384 | 256 | 0.59 | 0.99 | 0.60 |
| 16384 | 1024 | 1.35 | 1.54 | 0.88 |
| 16384 | 2048 | 2.31 | 2.24 | 1.03 |
| 16384 | 4096 | 4.20 | 3.63 | 1.16 |
| 16384 | 8192 | 8.26 | 7.51 | 1.10 |
| 32768 | 64 | 0.47 | 1.48 | 0.32 |
| 32768 | 256 | 0.68 | 1.70 | 0.40 |
| 32768 | 1024 | 1.63 | 2.35 | 0.69 |
| 32768 | 2048 | 2.87 | 3.19 | 0.90 |
| 32768 | 4096 | 5.26 | 4.86 | 1.08 |
| 32768 | 8192 | 10.17 | 9.92 | 1.03 |
| 65536 | 64 | 0.50 | 2.81 | 0.18 |
| 65536 | 256 | 0.78 | 3.12 | 0.25 |
| 65536 | 1024 | 2.02 | 3.99 | 0.51 |
| 65536 | 2048 | 3.58 | 5.06 | 0.71 |
| 65536 | 4096 | 6.68 | 7.40 | 0.90 |
| 65536 | 8192 | 13.08 | 15.35 | 0.85 |
| DGXA100 | | | | |
| 1024 | 64 | 0.28 | 0.09 | 3.05 |
| 1024 | 256 | 0.30 | 0.17 | 1.71 |
| 1024 | 1024 | 0.51 | 0.39 | 1.31 |
| 1024 | 2048 | 0.81 | 0.68 | 1.20 |
| 1024 | 4096 | 1.43 | 1.24 | 1.16 |
| 1024 | 8192 | 2.63 | 2.42 | 1.09 |
| 2048 | 64 | 0.25 | 0.12 | 2.15 |
| 2048 | 256 | 0.29 | 0.22 | 1.36 |
| 2048 | 1024 | 0.53 | 0.44 | 1.20 |
| 2048 | 2048 | 0.86 | 0.73 | 1.18 |
| 2048 | 4096 | 1.51 | 1.30 | 1.16 |
| 2048 | 8192 | 2.81 | 2.55 | 1.10 |
| 4096 | 64 | 0.31 | 0.20 | 1.57 |
| 4096 | 256 | 0.35 | 0.33 | 1.08 |
| 4096 | 1024 | 0.63 | 0.57 | 1.10 |
| 4096 | 2048 | 1.08 | 0.86 | 1.26 |
| 4096 | 4096 | 2.11 | 1.44 | 1.46 |
| 4096 | 8192 | 3.33 | 2.81 | 1.19 |
| 5120 | 64 | 0.36 | 0.22 | 1.63 |
| 5120 | 256 | 0.37 | 0.37 | 0.98 |
| 5120 | 1024 | 0.66 | 0.62 | 1.07 |
| 5120 | 2048 | 1.05 | 0.92 | 1.15 |
| 5120 | 4096 | 1.83 | 1.51 | 1.21 |
| 5120 | 8192 | 3.35 | 2.94 | 1.14 |
| 6144 | 64 | 0.29 | 0.25 | 1.18 |
| 6144 | 256 | 0.37 | 0.43 | 0.86 |
| 6144 | 1024 | 0.70 | 0.68 | 1.03 |
| 6144 | 2048 | 1.08 | 0.98 | 1.11 |
| 6144 | 4096 | 1.89 | 1.57 | 1.20 |
| 6144 | 8192 | 3.49 | 3.07 | 1.14 |
| 8192 | 64 | 0.29 | 0.31 | 0.95 |
| 8192 | 256 | 0.37 | 0.52 | 0.70 |
| 8192 | 1024 | 0.71 | 0.79 | 0.90 |
| 8192 | 2048 | 1.16 | 1.10 | 1.06 |
| 8192 | 4096 | 2.04 | 1.70 | 1.20 |
| 8192 | 8192 | 3.86 | 3.32 | 1.16 |
| 16384 | 64 | 0.31 | 0.55 | 0.56 |
| 16384 | 256 | 0.42 | 0.93 | 0.45 |
| 16384 | 1024 | 0.87 | 1.24 | 0.70 |
| 16384 | 2048 | 1.46 | 1.57 | 0.93 |
| 16384 | 4096 | 2.60 | 2.23 | 1.17 |
| 16384 | 8192 | 5.15 | 4.69 | 1.10 |
| 32768 | 64 | 0.33 | 1.03 | 0.32 |
| 32768 | 256 | 0.49 | 1.78 | 0.28 |
| 32768 | 1024 | 1.11 | 2.18 | 0.51 |
| 32768 | 2048 | 1.90 | 2.54 | 0.75 |
| 32768 | 4096 | 3.45 | 3.31 | 1.04 |
| 32768 | 8192 | 6.46 | 6.43 | 1.00 |
| 65536 | 64 | 0.36 | 2.19 | 0.16 |
| 65536 | 256 | 0.56 | 3.41 | 0.17 |
| 65536 | 1024 | 1.39 | 4.01 | 0.35 |
| 65536 | 2048 | 2.48 | 4.45 | 0.56 |
| 65536 | 4096 | 4.50 | 5.44 | 0.83 |
| 65536 | 8192 | 8.49 | 10.55 | 0.80 |
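To make the graph-capture point concrete, below is a minimal sketch (not part of this commit) of capturing one embedding forward/backward iteration, following the whole-network capture recipe from the PyTorch CUDA graphs documentation. It assumes a PyTorch build that exposes the `torch.cuda.graph` / `torch.cuda.CUDAGraph` API (added in releases after this commit); the token count is kept below the new 6144 threshold so the custom backward kernel, rather than the thrust path, runs inside the capture.

```
import torch
import torch.nn as nn

# Illustrative sizes: vocabulary from the benchmark script below,
# num_tokens kept under the 6144 threshold so the capturable custom
# backward kernel is selected.
vocabulary_size, num_tokens, hidden_dim = 30522, 1024, 1024
emb = nn.Embedding(vocabulary_size, hidden_dim).cuda()
static_inds = torch.randint(vocabulary_size, (num_tokens,), device="cuda")
static_dy = torch.randn(num_tokens, hidden_dim, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        emb.zero_grad(set_to_none=True)
        emb(static_inds).backward(static_dy)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward+backward iteration into a CUDA graph.
g = torch.cuda.CUDAGraph()
emb.zero_grad(set_to_none=True)  # grads get (re)allocated from the graph's pool
with torch.cuda.graph(g):
    y = emb(static_inds)
    y.backward(static_dy)

# Replaying reruns the captured kernels (including the custom backward)
# with whatever data is currently in static_inds / static_dy.
static_inds.copy_(torch.randint(vocabulary_size, (num_tokens,), device="cuda"))
g.replay()
```

With num_tokens at or above the threshold, the thrust path's CPU syncs would prevent this capture, which is part of the motivation for raising the threshold.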
Here is the script used to generate the times (30522 is the vocabulary size used in the MLPerf BERT benchmark, hence its use in this example):

```
import torch
import torch.nn as nn
import time

vocabulary_size = 30522

for num_tokens in [512, 1024, 2048, 4096, 5120, 6144, 8192, 16384, 32768, 65536]:
    for hidden_dim in [64, 256, 1024, 2048, 4096, 8192]:
        bprop_time_avg = 0.0
        emb = nn.Embedding(vocabulary_size, hidden_dim).cuda()
        for trial in range(0, 10):
            inds = torch.round(torch.rand(num_tokens) * (vocabulary_size - 1)).to(dtype=torch.int64).cuda()
            y = emb(inds)
            dy = torch.randn_like(y)
            torch.cuda.synchronize()
            t_start_bwd = time.time()
            y.backward(dy)
            torch.cuda.synchronize()
            t_stop_bwd = time.time()
            bprop_time_avg += t_stop_bwd - t_start_bwd
        bprop_time_avg /= 10.0
        print("bprop num_tokens %5d, num_features %5d, time %2.6f" % (num_tokens, hidden_dim, bprop_time_avg))
```
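As a side note, the same per-iteration backward time can also be measured with CUDA events rather than `time.time()` around `torch.cuda.synchronize()`. A small sketch of that variant (not part of this commit, sizes illustrative):

```
import torch
import torch.nn as nn

vocabulary_size, num_tokens, hidden_dim = 30522, 1024, 1024
emb = nn.Embedding(vocabulary_size, hidden_dim).cuda()
inds = torch.randint(vocabulary_size, (num_tokens,), device="cuda")
y = emb(inds)
dy = torch.randn_like(y)

start = torch.cuda.Event(enable_timing=True)
stop = torch.cuda.Event(enable_timing=True)
start.record()            # timestamps are recorded on the GPU stream
y.backward(dy)
stop.record()
torch.cuda.synchronize()  # wait for both events before reading them
print("bprop time %.3f ms" % start.elapsed_time(stop))  # elapsed_time returns milliseconds
```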
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49913
Reviewed By: jbschlosser
Differential Revision: D27727738
Pulled By: ngimel
fbshipit-source-id: fa497b6745b6d20bb11352579ed9eb5b66a8b1e2

Author: Sukru Eryilmaz