Increase token count threshold for calling thrust sort in embedding backward (#49913)
Summary:
Increases the token count threshold below which embedding backward uses the custom CUDA kernel instead of the thrust sort path, expanding the range of inputs the custom kernel handles. Speedups of the custom kernel over the thrust path on DGXV100-128GB and DGXA100-640GB are given below. I picked 6144 as the new threshold since, below it, the custom CUDA kernel is faster in nearly all measured cases. One important advantage of the custom CUDA kernel is that it allows CUDA graph capture, whereas the thrust code path performs CPU syncs that prohibit graph capture (the times below are collected without graph capture). For reference, the MLPerf BERT benchmark uses num_features=1024.
DGXV100:

| num_tokens | num_features | thrust path (ms) | custom kernel (ms) | speedup |
| -- | -- | -- | -- | -- |
| 1024 | 64 | 0.36 | 0.18 | 2.04
| 1024 | 256 | 0.43 | 0.30 | 1.46
| 1024 | 1024 | 0.89 | 0.74 | 1.20
| 1024 | 2048 | 1.50 | 1.33 | 1.12
| 1024 | 4096 | 2.71 | 2.50 | 1.08
| 1024 | 8192 | 5.07 | 4.89 | 1.04
| 2048 | 64 | 0.33 | 0.23 | 1.46
| 2048 | 256 | 0.41 | 0.33 | 1.26
| 2048 | 1024 | 0.92 | 0.79 | 1.17
| 2048 | 2048 | 1.54 | 1.38 | 1.11
| 2048 | 4096 | 2.80 | 2.54 | 1.10
| 2048 | 8192 | 5.29 | 4.98 | 1.06
| 4096 | 64 | 0.46 | 0.32 | 1.43
| 4096 | 256 | 0.50 | 0.47 | 1.07
| 4096 | 1024 | 1.02 | 0.88 | 1.15
| 4096 | 2048 | 1.70 | 1.59 | 1.07
| 4096 | 4096 | 3.06 | 2.68 | 1.14
| 4096 | 8192 | 5.79 | 5.28 | 1.10
| 5120 | 64 | 0.42 | 0.33 | 1.28
| 5120 | 256 | 0.51 | 0.46 | 1.11
| 5120 | 1024 | 1.06 | 0.93 | 1.14
| 5120 | 2048 | 1.77 | 1.55 | 1.14
| 5120 | 4096 | 3.18 | 2.76 | 1.15
| 5120 | 8192 | 6.24 | 5.46 | 1.14
| 6144 | 64 | 0.42 | 0.36 | 1.17
| 6144 | 256 | 0.52 | 0.50 | 1.05
| 6144 | 1024 | 1.10 | 0.98 | 1.13
| 6144 | 2048 | 1.85 | 1.61 | 1.15
| 6144 | 4096 | 3.34 | 2.84 | 1.17
| 6144 | 8192 | 6.19 | 5.69 | 1.09
| 8192 | 64 | 0.42 | 0.48 | 0.88
| 8192 | 256 | 0.51 | 0.65 | 0.78
| 8192 | 1024 | 1.14 | 1.12 | 1.01
| 8192 | 2048 | 1.92 | 1.77 | 1.09
| 8192 | 4096 | 3.49 | 3.03 | 1.15
| 8192 | 8192 | 6.59 | 5.96 | 1.11
| 16384 | 64 | 0.46 | 0.82 | 0.56
| 16384 | 256 | 0.59 | 0.99 | 0.60
| 16384 | 1024 | 1.35 | 1.54 | 0.88
| 16384 | 2048 | 2.31 | 2.24 | 1.03
| 16384 | 4096 | 4.20 | 3.63 | 1.16
| 16384 | 8192 | 8.26 | 7.51 | 1.10
| 32768 | 64 | 0.47 | 1.48 | 0.32
| 32768 | 256 | 0.68 | 1.70 | 0.40
| 32768 | 1024 | 1.63 | 2.35 | 0.69
| 32768 | 2048 | 2.87 | 3.19 | 0.90
| 32768 | 4096 | 5.26 | 4.86 | 1.08
| 32768 | 8192 | 10.17 | 9.92 | 1.03
| 65536 | 64 | 0.50 | 2.81 | 0.18
| 65536 | 256 | 0.78 | 3.12 | 0.25
| 65536 | 1024 | 2.02 | 3.99 | 0.51
| 65536 | 2048 | 3.58 | 5.06 | 0.71
| 65536 | 4096 | 6.68 | 7.40 | 0.90
| 65536 | 8192 | 13.08 | 15.35 | 0.85

DGXA100:

| num_tokens | num_features | thrust path (ms) | custom kernel (ms) | speedup |
| -- | -- | -- | -- | -- |
| 1024 | 64 | 0.28 | 0.09 | 3.05
| 1024 | 256 | 0.30 | 0.17 | 1.71
| 1024 | 1024 | 0.51 | 0.39 | 1.31
| 1024 | 2048 | 0.81 | 0.68 | 1.20
| 1024 | 4096 | 1.43 | 1.24 | 1.16
| 1024 | 8192 | 2.63 | 2.42 | 1.09
| 2048 | 64 | 0.25 | 0.12 | 2.15
| 2048 | 256 | 0.29 | 0.22 | 1.36
| 2048 | 1024 | 0.53 | 0.44 | 1.20
| 2048 | 2048 | 0.86 | 0.73 | 1.18
| 2048 | 4096 | 1.51 | 1.30 | 1.16
| 2048 | 8192 | 2.81 | 2.55 | 1.10
| 4096 | 64 | 0.31 | 0.20 | 1.57
| 4096 | 256 | 0.35 | 0.33 | 1.08
| 4096 | 1024 | 0.63 | 0.57 | 1.10
| 4096 | 2048 | 1.08 | 0.86 | 1.26
| 4096 | 4096 | 2.11 | 1.44 | 1.46
| 4096 | 8192 | 3.33 | 2.81 | 1.19
| 5120 | 64 | 0.36 | 0.22 | 1.63
| 5120 | 256 | 0.37 | 0.37 | 0.98
| 5120 | 1024 | 0.66 | 0.62 | 1.07
| 5120 | 2048 | 1.05 | 0.92 | 1.15
| 5120 | 4096 | 1.83 | 1.51 | 1.21
| 5120 | 8192 | 3.35 | 2.94 | 1.14
| 6144 | 64 | 0.29 | 0.25 | 1.18
| 6144 | 256 | 0.37 | 0.43 | 0.86
| 6144 | 1024 | 0.70 | 0.68 | 1.03
| 6144 | 2048 | 1.08 | 0.98 | 1.11
| 6144 | 4096 | 1.89 | 1.57 | 1.20
| 6144 | 8192 | 3.49 | 3.07 | 1.14
| 8192 | 64 | 0.29 | 0.31 | 0.95
| 8192 | 256 | 0.37 | 0.52 | 0.70
| 8192 | 1024 | 0.71 | 0.79 | 0.90
| 8192 | 2048 | 1.16 | 1.10 | 1.06
| 8192 | 4096 | 2.04 | 1.70 | 1.20
| 8192 | 8192 | 3.86 | 3.32 | 1.16
| 16384 | 64 | 0.31 | 0.55 | 0.56
| 16384 | 256 | 0.42 | 0.93 | 0.45
| 16384 | 1024 | 0.87 | 1.24 | 0.70
| 16384 | 2048 | 1.46 | 1.57 | 0.93
| 16384 | 4096 | 2.60 | 2.23 | 1.17
| 16384 | 8192 | 5.15 | 4.69 | 1.10
| 32768 | 64 | 0.33 | 1.03 | 0.32
| 32768 | 256 | 0.49 | 1.78 | 0.28
| 32768 | 1024 | 1.11 | 2.18 | 0.51
| 32768 | 2048 | 1.90 | 2.54 | 0.75
| 32768 | 4096 | 3.45 | 3.31 | 1.04
| 32768 | 8192 | 6.46 | 6.43 | 1.00
| 65536 | 64 | 0.36 | 2.19 | 0.16
| 65536 | 256 | 0.56 | 3.41 | 0.17
| 65536 | 1024 | 1.39 | 4.01 | 0.35
| 65536 | 2048 | 2.48 | 4.45 | 0.56
| 65536 | 4096 | 4.50 | 5.44 | 0.83
| 65536 | 8192 | 8.49 | 10.55 | 0.80
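
For context, the change amounts to raising the cutoff in the CUDA embedding backward dispatch. Below is a minimal Python sketch of that decision; `custom_kernel_backward` and `thrust_sort_backward` are hypothetical stand-ins for the two ATen code paths, not real functions (the actual dispatch is C++ inside ATen's CUDA embedding backward):

```python
# Hypothetical sketch of the dispatch this PR adjusts; not the actual
# ATen implementation, just an illustration of the decision it makes.
NUM_TOKENS_THRESHOLD = 6144  # raised by this PR

def embedding_dense_backward_dispatch(grad_output, indices, num_weights):
    if indices.numel() <= NUM_TOKENS_THRESHOLD:
        # Custom CUDA kernel: no CPU syncs, so it is CUDA-graph capturable,
        # and faster for small-to-medium token counts (see tables above).
        return custom_kernel_backward(grad_output, indices, num_weights)
    # Thrust sort path: wins only at very large token counts, but its CPU
    # synchronizations prohibit graph capture.
    return thrust_sort_backward(grad_output, indices, num_weights)
```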
Here is the script used to generate the timings (30522 is the vocabulary size used in the MLPerf BERT benchmark, hence its use in this example):
```python
import time

import torch
import torch.nn as nn

vocabulary_size = 30522  # MLPerf BERT vocabulary size

for num_tokens in [512, 1024, 2048, 4096, 5120, 6144, 8192, 16384, 32768, 65536]:
    for hidden_dim in [64, 256, 1024, 2048, 4096, 8192]:
        bprop_time_avg = 0.0
        emb = nn.Embedding(vocabulary_size, hidden_dim).cuda()
        for trial in range(10):
            # Random token indices in [0, vocabulary_size).
            inds = torch.round(torch.rand(num_tokens) * (vocabulary_size - 1)).to(dtype=torch.int64).cuda()
            y = emb(inds)
            dy = torch.randn_like(y)
            # Synchronize around backward so wall-clock time measures GPU work.
            torch.cuda.synchronize()
            t_start_bwd = time.time()
            y.backward(dy)
            torch.cuda.synchronize()
            t_stop_bwd = time.time()
            bprop_time_avg += t_stop_bwd - t_start_bwd
        bprop_time_avg /= 10.0
        print("bprop num_tokens %5d, num_features %5d, time %2.6f" % (num_tokens, hidden_dim, bprop_time_avg))
```
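
Since graph capture is the main motivation, here is a minimal sketch of capturing the embedding forward+backward, assuming a PyTorch build that exposes the `torch.cuda.CUDAGraph` / `torch.cuda.graph` API; keeping num_tokens under the new threshold routes the backward through the sync-free custom kernel, which is what makes capture possible:

```python
import torch
import torch.nn as nn

# Sketch only: assumes the torch.cuda.CUDAGraph / torch.cuda.graph API is
# available. num_tokens is kept below the 6144 threshold so backward uses
# the custom kernel; the thrust path's CPU syncs would break capture.
vocabulary_size, hidden_dim, num_tokens = 30522, 1024, 1024
emb = nn.Embedding(vocabulary_size, hidden_dim).cuda()
static_inds = torch.randint(vocabulary_size, (num_tokens,), device="cuda")
static_dy = torch.randn(num_tokens, hidden_dim, device="cuda")

# Warm up on a side stream, as the CUDA-graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        emb.weight.grad = None
        emb(static_inds).backward(static_dy)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward+backward; replay() re-runs it with no CPU launches.
g = torch.cuda.CUDAGraph()
emb.weight.grad = None
with torch.cuda.graph(g):
    emb(static_inds).backward(static_dy)
g.replay()
```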
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49913
Reviewed By: jbschlosser
Differential Revision: D27727738
Pulled By: ngimel
fbshipit-source-id: fa497b6745b6d20bb11352579ed9eb5b66a8b1e2