pytorch
f29b9574 - [cuda] vectorized implementation for layer_norm_grad_input_kernel (#111021)

Commit
1 year ago
[cuda] vectorized implementation for layer_norm_grad_input_kernel (#111021) Using vectorized loads/stores makes the `layer_norm_grad_input_kernel` generally faster. This PR accelerates medium and larger problem sizes. ```python def run_model_on_device(fs, X, gO, device_string, numeric_type): ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type) ln.reset_parameters() X.grad = None ln.zero_grad(set_to_none=True) out = ln(X) out.backward(gO) return (ln.weight.grad, ln.bias.grad) def run_correctness_test(eps_weight, eps_bias): dtype = torch.float for val in l_inputs: bs = val[0][0] fs = val[0][1] mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float) X = mean_adjustment * torch.randn( bs, fs, device="cpu", dtype=torch.float, requires_grad=True ) X = X.detach().requires_grad_() gO = torch.rand_like(X) X_gpu = X.to("cuda") X_gpu = X_gpu.detach().requires_grad_() gO_gpu = gO.to("cuda") gO_gpu = gO_gpu.detach().requires_grad_() grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype) grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype) weight_grad_gpu_target = grad_gpu[0].detach().to("cpu") bias_grad_gpu_target = grad_gpu[1].detach().to("cpu") weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target) weight_mismatches = (weight_delta >= eps_weight).nonzero() weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100 bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target) bias_mismatches = (bias_delta >= eps_bias).nonzero() bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100 if weight_mismatch_pct > 0 or bias_mismatch_pct > 0: print( "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format( fs, bs, weight_mismatch_pct, bias_mismatch_pct ) ) # Run the correctness tests run_correctness_test(0.01, 0.01) torch.cuda.synchronize() # Allocate a tensor equal to L2 cache size on A100 GPUs l2_cache_flusher = torch.empty(int(80 * (1024**2)), dtype=torch.float, device="cuda") # Run the performance tests. We need to run this at global scope because otherwise # the `ln` and `gO` objects are likely removed by the JIT compiler results = [] for dtype in (torch.float, torch.half): for val in l_inputs: bs = val[0][0] fs = val[0][1] iterations = val[1] ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype) X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True) gO = torch.rand_like(X) # Try to measure FWD and BWD pass in the same loop l_ev_start_fwd = [torch.cuda.Event(enable_timing=True)] * iterations l_ev_stop_fwd = [torch.cuda.Event(enable_timing=True)] * iterations l_ev_stop_bwd = [torch.cuda.Event(enable_timing=True)] * iterations l_fwd_times = [] l_bwd_times = [] torch.cuda.synchronize() for i in range(iterations): l2_cache_flusher.zero_() torch.cuda._sleep(1_000_000) X.grad = None ln.zero_grad(set_to_none=True) l_ev_start_fwd[i].record() out = ln(X) l_ev_stop_fwd[i].record() out.backward(gO) l_ev_stop_bwd[i].record() torch.cuda.synchronize() l_fwd_times = [] l_bwd_times = [] for i in range(iterations): l_fwd_times.append(l_ev_start_fwd[i].elapsed_time(l_ev_stop_fwd[i])) l_bwd_times.append(l_ev_stop_fwd[i].elapsed_time(l_ev_stop_bwd[i])) print( "({}, {}, {}, fwd_ms, bwd_ms)|{:.3f}|{:.3f}".format( dtype, bs, fs, sum(l_fwd_times) / iterations * 1000, sum(l_bwd_times) / iterations * 1000, ) ) ``` Results in the attached picture: <img width="314" alt="Screenshot 2023-10-16 at 11 08 25 AM" src="https://github.com/pytorch/pytorch/assets/23515689/ce571fc5-c84e-47eb-95f6-9faa44042cc1"> I also isolated the previous implementation and the vectorized one into a native CUDA program and the speedup is confirmed. **Average speedup = 21.73%** ``` Size (2048, 2048); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0. [16/1529] reference = 0.0560 (ms); optimized = 0.0435 (ms); bw_opt = 1437.54 GB/s; speedup = 28.78% Size (4096, 512); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0. reference = 0.0220 (ms); optimized = 0.0174 (ms); bw_opt = 1797.26 GB/s; speedup = 26.44% Size (1024, 512); Mismatches: dX = 0 out of 524288. Max missmatch idx = 0. reference = 0.0101 (ms); optimized = 0.0082 (ms); bw_opt = 953.49 GB/s; speedup = 22.97% Size (1024, 256); Mismatches: dX = 1 out of 262144. Max missmatch idx = 22411. reference = 0.0082 (ms); optimized = 0.0075 (ms); bw_opt = 521.14 GB/s; speedup = 9.21% Size (1024, 1024); Mismatches: dX = 0 out of 1048576. Max missmatch idx = 0. reference = 0.0137 (ms); optimized = 0.0108 (ms); bw_opt = 1447.42 GB/s; speedup = 26.93% Size (2048, 512); Mismatches: dX = 0 out of 1048576. Max missmatch idx = 0. reference = 0.0141 (ms); optimized = 0.0116 (ms); bw_opt = 1349.79 GB/s; speedup = 21.81% Size (2048, 256); Mismatches: dX = 0 out of 524288. Max missmatch idx = 0. reference = 0.0108 (ms); optimized = 0.0102 (ms); bw_opt = 768.90 GB/s; speedup = 6.09% Size (1024, 128); Mismatches: dX = 1 out of 131072. Max missmatch idx = 9165. reference = 0.0070 (ms); optimized = 0.0068 (ms); bw_opt = 288.56 GB/s; speedup = 2.81% Size (1024, 2048); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0. reference = 0.0223 (ms); optimized = 0.0164 (ms); bw_opt = 1905.58 GB/s; speedup = 35.90% Size (1024, 768); Mismatches: dX = 3 out of 786432. Max missmatch idx = 507105. reference = 0.0113 (ms); optimized = 0.0101 (ms); bw_opt = 1160.00 GB/s; speedup = 11.79% Size (2048, 128); Mismatches: dX = 0 out of 262144. Max missmatch idx = 0. reference = 0.0097 (ms); optimized = 0.0089 (ms); bw_opt = 440.97 GB/s; speedup = 9.12% Size (2048, 1024); Mismatches: dX = 0 out of 2097152. Max missmatch idx = 0. reference = 0.0204 (ms); optimized = 0.0166 (ms); bw_opt = 1881.43 GB/s; speedup = 22.81% Size (4096, 256); Mismatches: dX = 1 out of 1048576. Max missmatch idx = 601965. reference = 0.0156 (ms); optimized = 0.0154 (ms); bw_opt = 1016.47 GB/s; speedup = 1.24% Size (4096, 1024); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0. reference = 0.0411 (ms); optimized = 0.0417 (ms); bw_opt = 1499.55 GB/s; speedup = -1.43% Size (4096, 4096); Mismatches: dX = 0 out of 16777216. Max missmatch idx = 0. reference = 0.2323 (ms); optimized = 0.2077 (ms); bw_opt = 1203.75 GB/s; speedup = 11.83% Size (1024, 4096); Mismatches: dX = 0 out of 4194304. Max missmatch idx = 0. reference = 0.0659 (ms); optimized = 0.0570 (ms); bw_opt = 1096.51 GB/s; speedup = 15.60% Size (1024, 3072); Mismatches: dX = 0 out of 3145728. Max missmatch idx = 0. reference = 0.0425 (ms); optimized = 0.0299 (ms); bw_opt = 1568.10 GB/s; speedup = 42.11% Size (1024, 2464); Mismatches: dX = 8 out of 2523136. Max missmatch idx = 2087476. reference = 0.0292 (ms); optimized = 0.0230 (ms); bw_opt = 1636.18 GB/s; speedup = 27.07% Size (1024, 800); Mismatches: dX = 1 out of 819200. Max missmatch idx = 652342. reference = 0.0114 (ms); optimized = 0.0104 (ms); bw_opt = 1175.05 GB/s; speedup = 9.63% Size (1024, 6144); Mismatches: dX = 0 out of 6291456. Max missmatch idx = 0. reference = 0.0973 (ms); optimized = 0.0844 (ms); bw_opt = 1110.87 GB/s; speedup = 15.28% Size (1024, 4904); Mismatches: dX = 6 out of 5021696. Max missmatch idx = 4670210. reference = 0.0814 (ms); optimized = 0.0721 (ms); bw_opt = 1037.99 GB/s; speedup = 12.90% Size (4096, 2048); Mismatches: dX = 0 out of 8388608. Max missmatch idx = 0. reference = 0.0990 (ms); optimized = 0.0770 (ms); bw_opt = 1623.58 GB/s; speedup = 28.54% Size (1024, 1860); Mismatches: dX = 0 out of 1904640. Max missmatch idx = 0. reference = 0.0219 (ms); optimized = 0.0174 (ms); bw_opt = 1631.12 GB/s; speedup = 25.75% Size (1024, 20160); Mismatches: dX = 23 out of 20643840. Max missmatch idx = 20274656. reference = 0.3054 (ms); optimized = 0.2600 (ms); bw_opt = 1183.08 GB/s; speedup = 17.45% Size (3072, 256); Mismatches: dX = 0 out of 786432. Max missmatch idx = 0. reference = 0.0129 (ms); optimized = 0.0127 (ms); bw_opt = 925.71 GB/s; speedup = 1.69% Size (4096, 128); Mismatches: dX = 3 out of 524288. Max missmatch idx = 451331. reference = 0.0128 (ms); optimized = 0.0129 (ms); bw_opt = 608.06 GB/s; speedup = -0.74% Size (512, 128); Mismatches: dX = 0 out of 65536. Max missmatch idx = 0. reference = 0.0062 (ms); optimized = 0.0061 (ms); bw_opt = 161.25 GB/s; speedup = 2.35% Size (2048, 64); Mismatches: dX = 0 out of 131072. Max missmatch idx = 0. reference = 0.0084 (ms); optimized = 0.0086 (ms); bw_opt = 228.70 GB/s; speedup = -2.49% Size (3072, 2048); Mismatches: dX = 0 out of 6291456. Max missmatch idx = 0. reference = 0.0770 (ms); optimized = 0.0614 (ms); bw_opt = 1527.43 GB/s; speedup = 25.44% Size (3200, 104); Mismatches: dX = 0 out of 332800. Max missmatch idx = 0. reference = 0.0105 (ms); optimized = 0.0113 (ms); bw_opt = 440.93 GB/s; speedup = -6.96% Size (1152, 384); Mismatches: dX = 0 out of 442368. Max missmatch idx = 0. reference = 0.0102 (ms); optimized = 0.0084 (ms); bw_opt = 786.48 GB/s; speedup = 21.59% Size (131072, 64); Mismatches: dX = 12 out of 8388608. Max missmatch idx = 7659094. reference = 0.2054 (ms); optimized = 0.2873 (ms); bw_opt = 438.49 GB/s; speedup = -28.51% Size (64, 131072); Mismatches: dX = 0 out of 8388608. Max missmatch idx = 0. reference = 0.8372 (ms); optimized = 0.3295 (ms); bw_opt = 379.37 GB/s; speedup = 154.09% Size (131072, 128); Mismatches: dX = 18 out of 16777216. Max missmatch idx = 16158071. reference = 0.2296 (ms); optimized = 0.3116 (ms); bw_opt = 805.47 GB/s; speedup = -26.31% Size (128, 131072); Mismatches: dX = 0 out of 16777216. Max missmatch idx = 0. reference = 0.9297 (ms); optimized = 0.3785 (ms); bw_opt = 660.52 GB/s; speedup = 145.64% Size (131072, 256); Mismatches: dX = 47 out of 33554432. Max missmatch idx = 33062426. reference = 0.3003 (ms); optimized = 0.4231 (ms); bw_opt = 1184.07 GB/s; speedup = -29.02% Size (256, 131072); Mismatches: dX = 0 out of 33554432. Max missmatch idx = 0. reference = 1.0449 (ms); optimized = 0.4828 (ms); bw_opt = 1035.63 GB/s; speedup = 116.43% Average speedup = 21.73% ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/111021 Approved by: https://github.com/malfet
Committer
Parents
Loading