[cuda] vectorized implementation for layer_norm_grad_input_kernel (#111021)
Using vectorized loads/stores makes the `layer_norm_grad_input_kernel` generally faster. This PR accelerates medium and larger problem sizes.
```python
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    """Run one LayerNorm forward/backward pass and return the parameter grads.

    A fresh ``torch.nn.LayerNorm((fs,))`` module is created on every call, so
    the result depends only on ``X`` and ``gO``.

    Args:
        fs: Size of the normalized (last) dimension.
        X: Input tensor fed to the module; its ``.grad`` is reset to ``None``
            before the forward pass.
        gO: Upstream gradient passed to ``out.backward``.
        device_string: Device on which the LayerNorm module is created.
        numeric_type: dtype of the LayerNorm parameters.

    Returns:
        Tuple ``(weight_grad, bias_grad)`` taken from the module's parameters.
    """
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    # Drop any gradient left over from a previous run so that the backward
    # pass below does not accumulate into stale values.
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)
def run_correctness_test(eps_weight, eps_bias):
    """Compare CPU and CUDA LayerNorm parameter gradients for each test size.

    For every entry of the module-level ``l_inputs`` list, the same random
    input and upstream gradient are run through ``run_model_on_device`` on the
    CPU and on CUDA. Any element whose absolute difference reaches the given
    threshold counts as a mismatch. Mismatches are only reported on stdout;
    nothing is raised or returned.

    Args:
        eps_weight: Absolute-difference threshold for the weight gradient.
        eps_bias: Absolute-difference threshold for the bias gradient.
    """
    dtype = torch.float
    for val in l_inputs:
        bs = val[0][0]
        fs = val[0][1]

        # Scale each feature column by a random factor.
        mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
        X = mean_adjustment * torch.randn(bs, fs, device="cpu", dtype=torch.float)
        X = X.detach().requires_grad_()
        gO = torch.rand_like(X)

        # The CUDA copies must be leaf tensors of their own so that the two
        # runs do not share autograd state.
        X_gpu = X.detach().to("cuda").requires_grad_()
        gO_gpu = gO.to("cuda")

        grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
        grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
        weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
        bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

        weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
        weight_mismatches = (weight_delta >= eps_weight).nonzero()
        weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

        bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
        bias_mismatches = (bias_delta >= eps_bias).nonzero()
        bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

        if weight_mismatch_pct > 0 or bias_mismatch_pct > 0:
            # Report the size as (batch x features), the same order used by
            # the benchmark output.
            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    bs, fs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
# Run the correctness tests
run_correctness_test(0.01, 0.01)
torch.cuda.synchronize()

# Buffer that is zeroed before every timed iteration to flush the L2 cache.
# It holds 80 Mi float32 elements (320 MiB); intended to cover the L2 cache of
# A100 GPUs.
l2_cache_flusher = torch.empty(int(80 * (1024**2)), dtype=torch.float, device="cuda")

# Run the performance tests. We need to run this at global scope because otherwise
# the `ln` and `gO` objects are likely removed by the JIT compiler
results = []
for dtype in (torch.float, torch.half):
    for val in l_inputs:
        bs = val[0][0]
        fs = val[0][1]
        iterations = val[1]

        ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
        X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
        gO = torch.rand_like(X)

        # Try to measure FWD and BWD pass in the same loop.
        # Each iteration needs its own event objects: `[event] * iterations`
        # would repeat one shared event, and only the last iteration would be
        # measured.
        l_ev_start_fwd = [
            torch.cuda.Event(enable_timing=True) for _ in range(iterations)
        ]
        l_ev_stop_fwd = [
            torch.cuda.Event(enable_timing=True) for _ in range(iterations)
        ]
        l_ev_stop_bwd = [
            torch.cuda.Event(enable_timing=True) for _ in range(iterations)
        ]

        torch.cuda.synchronize()
        for i in range(iterations):
            l2_cache_flusher.zero_()
            # NOTE(review): presumably keeps the GPU busy so the host can
            # enqueue the following work ahead of time -- confirm.
            torch.cuda._sleep(1_000_000)

            X.grad = None
            ln.zero_grad(set_to_none=True)

            l_ev_start_fwd[i].record()
            out = ln(X)
            l_ev_stop_fwd[i].record()
            out.backward(gO)
            l_ev_stop_bwd[i].record()
        # All events must have completed before their timings can be read.
        torch.cuda.synchronize()

        l_fwd_times = [
            start.elapsed_time(stop)
            for start, stop in zip(l_ev_start_fwd, l_ev_stop_fwd)
        ]
        l_bwd_times = [
            start.elapsed_time(stop)
            for start, stop in zip(l_ev_stop_fwd, l_ev_stop_bwd)
        ]

        # `elapsed_time` already returns milliseconds, so the averages are
        # printed unscaled to match the `fwd_ms` / `bwd_ms` labels.
        print(
            "({}, {}, {}, fwd_ms, bwd_ms)|{:.3f}|{:.3f}".format(
                dtype,
                bs,
                fs,
                sum(l_fwd_times) / iterations,
                sum(l_bwd_times) / iterations,
            )
        )
```
Results in the attached picture:
<img width="314" alt="Screenshot 2023-10-16 at 11 08 25 AM" src="https://github.com/pytorch/pytorch/assets/23515689/ce571fc5-c84e-47eb-95f6-9faa44042cc1">
I also isolated the previous implementation and the vectorized one into a native CUDA program and the speedup is confirmed. **Average speedup = 21.73%**
```
Size (2048, 2048); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0.
reference = 0.0560 (ms); optimized = 0.0435 (ms); bw_opt = 1437.54 GB/s; speedup = 28.78%
Size (4096, 512); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0.
reference = 0.0220 (ms); optimized = 0.0174 (ms); bw_opt = 1797.26 GB/s; speedup = 26.44%
Size (1024, 512); Mismatches: dX = 0 out of 524288. Max missmatch idx = 0.
reference = 0.0101 (ms); optimized = 0.0082 (ms); bw_opt = 953.49 GB/s; speedup = 22.97%
Size (1024, 256); Mismatches: dX = 1 out of 262144. Max missmatch idx = 22411.
reference = 0.0082 (ms); optimized = 0.0075 (ms); bw_opt = 521.14 GB/s; speedup = 9.21%
Size (1024, 1024); Mismatches: dX = 0 out of 1048576. Max missmatch idx = 0.
reference = 0.0137 (ms); optimized = 0.0108 (ms); bw_opt = 1447.42 GB/s; speedup = 26.93%
Size (2048, 512); Mismatches: dX = 0 out of 1048576. Max missmatch idx = 0.
reference = 0.0141 (ms); optimized = 0.0116 (ms); bw_opt = 1349.79 GB/s; speedup = 21.81%
Size (2048, 256); Mismatches: dX = 0 out of 524288. Max missmatch idx = 0.
reference = 0.0108 (ms); optimized = 0.0102 (ms); bw_opt = 768.90 GB/s; speedup = 6.09%
Size (1024, 128); Mismatches: dX = 1 out of 131072. Max missmatch idx = 9165.
reference = 0.0070 (ms); optimized = 0.0068 (ms); bw_opt = 288.56 GB/s; speedup = 2.81%
Size (1024, 2048); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0.
reference = 0.0223 (ms); optimized = 0.0164 (ms); bw_opt = 1905.58 GB/s; speedup = 35.90%
Size (1024, 768); Mismatches: dX = 3 out of 786432. Max missmatch idx = 507105.
reference = 0.0113 (ms); optimized = 0.0101 (ms); bw_opt = 1160.00 GB/s; speedup = 11.79%
Size (2048, 128); Mismatches: dX = 0 out of 262144. Max missmatch idx = 0.
reference = 0.0097 (ms); optimized = 0.0089 (ms); bw_opt = 440.97 GB/s; speedup = 9.12%
Size (2048, 1024); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0.
reference = 0.0204 (ms); optimized = 0.0166 (ms); bw_opt = 1881.43 GB/s; speedup = 22.81%
Size (4096, 256); Mismatches: dX = 1 out of 1048576. Max missmatch idx = 601965.
reference = 0.0156 (ms); optimized = 0.0154 (ms); bw_opt = 1016.47 GB/s; speedup = 1.24%
Size (4096, 1024); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0.
reference = 0.0411 (ms); optimized = 0.0417 (ms); bw_opt = 1499.55 GB/s; speedup = -1.43%
Size (4096, 4096); Mismatches: dX = 0 out of 16777216. Max missmatch idx = 0.
reference = 0.2323 (ms); optimized = 0.2077 (ms); bw_opt = 1203.75 GB/s; speedup = 11.83%
Size (1024, 4096); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0.
reference = 0.0659 (ms); optimized = 0.0570 (ms); bw_opt = 1096.51 GB/s; speedup = 15.60%
Size (1024, 3072); Mismatches: dX = 0 out of 3145728. Max missmatch idx = 0.
reference = 0.0425 (ms); optimized = 0.0299 (ms); bw_opt = 1568.10 GB/s; speedup = 42.11%
Size (1024, 2464); Mismatches: dX = 8 out of 2523136. Max missmatch idx = 2087476.
reference = 0.0292 (ms); optimized = 0.0230 (ms); bw_opt = 1636.18 GB/s; speedup = 27.07%
Size (1024, 800); Mismatches: dX = 1 out of 819200. Max missmatch idx = 652342.
reference = 0.0114 (ms); optimized = 0.0104 (ms); bw_opt = 1175.05 GB/s; speedup = 9.63%
Size (1024, 6144); Mismatches: dX = 0 out of 6291456. Max missmatch idx = 0.
reference = 0.0973 (ms); optimized = 0.0844 (ms); bw_opt = 1110.87 GB/s; speedup = 15.28%
Size (1024, 4904); Mismatches: dX = 6 out of 5021696. Max missmatch idx = 4670210.
reference = 0.0814 (ms); optimized = 0.0721 (ms); bw_opt = 1037.99 GB/s; speedup = 12.90%
Size (4096, 2048); Mismatches: dX = 0 out of 8388608. Max missmatch idx = 0.
reference = 0.0990 (ms); optimized = 0.0770 (ms); bw_opt = 1623.58 GB/s; speedup = 28.54%
Size (1024, 1860); Mismatches: dX = 0 out of 1904640. Max missmatch idx = 0.
reference = 0.0219 (ms); optimized = 0.0174 (ms); bw_opt = 1631.12 GB/s; speedup = 25.75%
Size (1024, 20160); Mismatches: dX = 23 out of 20643840. Max missmatch idx = 20274656.
reference = 0.3054 (ms); optimized = 0.2600 (ms); bw_opt = 1183.08 GB/s; speedup = 17.45%
Size (3072, 256); Mismatches: dX = 0 out of 786432. Max missmatch idx = 0.
reference = 0.0129 (ms); optimized = 0.0127 (ms); bw_opt = 925.71 GB/s; speedup = 1.69%
Size (4096, 128); Mismatches: dX = 3 out of 524288. Max missmatch idx = 451331.
reference = 0.0128 (ms); optimized = 0.0129 (ms); bw_opt = 608.06 GB/s; speedup = -0.74%
Size (512, 128); Mismatches: dX = 0 out of 65536. Max missmatch idx = 0.
reference = 0.0062 (ms); optimized = 0.0061 (ms); bw_opt = 161.25 GB/s; speedup = 2.35%
Size (2048, 64); Mismatches: dX = 0 out of 131072. Max missmatch idx = 0.
reference = 0.0084 (ms); optimized = 0.0086 (ms); bw_opt = 228.70 GB/s; speedup = -2.49%
Size (3072, 2048); Mismatches: dX = 0 out of 6291456. Max missmatch idx = 0.
reference = 0.0770 (ms); optimized = 0.0614 (ms); bw_opt = 1527.43 GB/s; speedup = 25.44%
Size (3200, 104); Mismatches: dX = 0 out of 332800. Max missmatch idx = 0.
reference = 0.0105 (ms); optimized = 0.0113 (ms); bw_opt = 440.93 GB/s; speedup = -6.96%
Size (1152, 384); Mismatches: dX = 0 out of 442368. Max missmatch idx = 0.
reference = 0.0102 (ms); optimized = 0.0084 (ms); bw_opt = 786.48 GB/s; speedup = 21.59%
Size (131072, 64); Mismatches: dX = 12 out of 8388608. Max missmatch idx = 7659094.
reference = 0.2054 (ms); optimized = 0.2873 (ms); bw_opt = 438.49 GB/s; speedup = -28.51%
Size (64, 131072); Mismatches: dX = 0 out of 8388608. Max missmatch idx = 0.
reference = 0.8372 (ms); optimized = 0.3295 (ms); bw_opt = 379.37 GB/s; speedup = 154.09%
Size (131072, 128); Mismatches: dX = 18 out of 16777216. Max missmatch idx = 16158071.
reference = 0.2296 (ms); optimized = 0.3116 (ms); bw_opt = 805.47 GB/s; speedup = -26.31%
Size (128, 131072); Mismatches: dX = 0 out of 16777216. Max missmatch idx = 0.
reference = 0.9297 (ms); optimized = 0.3785 (ms); bw_opt = 660.52 GB/s; speedup = 145.64%
Size (131072, 256); Mismatches: dX = 47 out of 33554432. Max missmatch idx = 33062426.
reference = 0.3003 (ms); optimized = 0.4231 (ms); bw_opt = 1184.07 GB/s; speedup = -29.02%
Size (256, 131072); Mismatches: dX = 0 out of 33554432. Max missmatch idx = 0.
reference = 1.0449 (ms); optimized = 0.4828 (ms); bw_opt = 1035.63 GB/s; speedup = 116.43%
Average speedup = 21.73%
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111021
Approved by: https://github.com/malfet