pytorch
b6f28334 - [pytorch] Layer norm backward speed gain with warp shuffles (#87445)

[pytorch] Layer norm backward speed gain with warp shuffles (#87445)

Test Plan:
```
Times below are Forward + Backward on A100

  Size          FP32  Gain    FP16  Gain
  256,   256  101.30    9%   103.9    6%
  512,   256  110.10   -4%   102.9   10%
 1024,   256  104.30    7%   102.4    6%
 2048,   256  107.60    4%   109.7    0%
 4096,   256  116.70    8%   109.1    0%
 6144,   256  106.10    7%   112.8    2%
 8192,   256  106.10    1%   109.7    2%
  256,   512  102.10    3%   108.5    1%
  512,   512  101.50   40%   105.9    4%
 1024,   512  109.70   20%   109.2   -1%
 2048,   512  107.40   24%   107.2    1%
 4096,   512  108.00    6%   110.6   -3%
 6144,   512  103.90   13%   105.8    7%
 8192,   512  138.70   14%   105.6    7%
  256,  1024  106.20    1%   102.9    6%
  512,  1024  104.50    4%   104.2    3%
 1024,  1024  126.90  -15%   103.9   10%
 2048,  1024  127.40  -15%   102.2    6%
 4096,  1024  117.70    6%   102.8   21%
 6144,  1024  165.30   11%   112.2   12%
 8192,  1024  211.90   11%   144.8   13%
  256,  1536  102.80   11%   103.1    6%
  512,  1536  103.30    9%   102.9   18%
 1024,  1536  111.00   -2%   117.2    7%
 2048,  1536  102.30   12%   132.1   -4%
 4096,  1536  165.50    5%   112.9   18%
 6144,  1536  236.60    5%   145.7   12%
 8192,  1536  307.80    5%   186.1   11%
  256,  2048  110.60   -1%   103.8    7%
  512,  2048  105.20    3%   105.6    1%
 1024,  2048  106.70    3%   114.8    3%
 2048,  2048  124.90    5%   109.7    0%
 4096,  2048  231.40    4%   129.9   10%
 6144,  2048  332.80    4%   182.5   11%
 8192,  2048  434.60    4%   235.2   11%
  256,  3072  111.60    8%   110.8    1%
  512,  3072  106.80    1%   104.6   10%
 1024,  3072  104.90    3%   109.9    4%
 2048,  3072  193.80    0%   106.2   10%
 4096,  3072  364.50    0%   187.8    5%
 6144,  3072  538.30    0%   267      5%
 8192,  3072  718.00   -1%   346.7    6%
  256,  4096  103.60    4%   110.2   -1%
  512,  4096  131.40  -11%   117     -7%
 1024,  4096  135.80    1%   104.8    7%
 2048,  4096  268.20    1%   149.4   10%
 4096,  4096  520.70    1%   268.5    9%
 6144,  4096  786.30    0%   389.8    9%
 8192,  4096 1043.50    0%   509     10%
```

Used the following script from ngimel:
```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")
c = Compare(results)
c.print()
```

Differential Revision: D40567574

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87445

Approved by: https://github.com/ngimel
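The "warp shuffles" in the commit title refer to CUDA intrinsics such as `__shfl_down_sync`, which let the 32 lanes of a warp exchange partial sums through registers instead of shared memory. As a rough illustration of the reduction pattern only (a plain-Python CPU simulation, not the PR's actual CUDA kernel), the shuffle-down tree sums 32 lane values in log2(32) = 5 rounds:

```python
# Hypothetical sketch: simulate a warp-shuffle sum reduction on the CPU.
# Each of 32 "lanes" adds the value held by the lane `offset` positions
# above it, halving `offset` each round; after 5 rounds lane 0 holds the
# sum of all lanes.
def warp_reduce_sum(vals):
    vals = list(vals)            # one value per lane; len(vals) == 32
    offset = len(vals) // 2
    while offset > 0:
        # Models __shfl_down_sync: lane i reads the value of lane i + offset.
        # (On hardware the out-of-range lanes are simply unused thereafter.)
        vals = [vals[i] + vals[(i + offset) % len(vals)] for i in range(len(vals))]
        offset //= 2
    return vals[0]               # lane 0 ends up with the full sum

print(warp_reduce_sum(range(32)))  # sum of 0..31 = 496
```

Only the lanes whose reads stay in range contribute to lane 0's result, which is why the wrap-around in this simulation does not affect the answer.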
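The benchmark above measures speed only; a separate, hypothetical sanity check (not part of this PR's test plan) is that the fused LayerNorm backward must produce the same input gradient as differentiating the definition `(x - mean) / sqrt(var + eps) * weight + bias` with autograd:

```python
import torch

# Hypothetical numerical check: compare nn.LayerNorm's fused backward
# against autograd through the decomposed layer norm formula, in double
# precision on CPU.
torch.manual_seed(0)
bs, fs = 8, 16
X = torch.randn(bs, fs, dtype=torch.double, requires_grad=True)
gO = torch.rand_like(X)

ln = torch.nn.LayerNorm((fs,)).double()
ln(X).backward(gO)
grad_fused = X.grad.clone()

# Reference: layer norm written out with mean/var, differentiated by autograd.
X.grad = None
mu = X.mean(dim=-1, keepdim=True)
var = X.var(dim=-1, unbiased=False, keepdim=True)
ref = (X - mu) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
ref.backward(gO)

print(torch.allclose(grad_fused, X.grad, atol=1e-10))  # should print True
```

The same check run with `device="cuda"` would exercise the kernel this commit modifies.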