[pytorch] Layer norm backward speed gain with warp shuffles (#87445)
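The backward pass reduces partial sums with warp shuffle intrinsics. As background, the sketch below shows the general shuffle-based reduction pattern this refers to (a minimal, self-contained block-level sum with hypothetical names `warpReduceSum`/`blockSum`; it is not the layer norm kernel changed by this PR): each warp folds its 32 lanes together with `__shfl_down_sync`, so only one value per warp goes through shared memory and the block needs a single `__syncthreads()` per reduction.
```
// Illustrative sketch only, not the PR's kernel: block-level sum where
// the intra-warp stage uses shuffles instead of shared memory.
#include <cstdio>

__inline__ __device__ float warpReduceSum(float val) {
  // Each step halves the number of lanes holding partial sums;
  // after log2(32) = 5 steps, lane 0 holds the warp's total.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

__global__ void blockSum(const float* in, float* out, int n) {
  __shared__ float warp_sums[32];  // one slot per warp in the block
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float val = (tid < n) ? in[tid] : 0.0f;

  val = warpReduceSum(val);  // intra-warp reduction, no shared memory

  int lane = threadIdx.x % warpSize;
  int warp = threadIdx.x / warpSize;
  if (lane == 0) warp_sums[warp] = val;  // one shared-memory write per warp
  __syncthreads();

  // First warp reduces the per-warp partials, again via shuffles.
  int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  if (warp == 0) {
    val = (lane < num_warps) ? warp_sums[lane] : 0.0f;
    val = warpReduceSum(val);
    if (lane == 0) atomicAdd(out, val);
  }
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  *out = 0.0f;
  blockSum<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);  // 1048576
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```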
Test Plan:
```
Times below are forward + backward on an A100. Size is (batch, features).

Size        FP32 time   Gain   FP16 time   Gain
256, 256 101.30 9% 103.9 6%
512, 256 110.10 -4% 102.9 10%
1024, 256 104.30 7% 102.4 6%
2048, 256 107.60 4% 109.7 0%
4096, 256 116.70 8% 109.1 0%
6144, 256 106.10 7% 112.8 2%
8192, 256 106.10 1% 109.7 2%
256, 512 102.10 3% 108.5 1%
512, 512 101.50 40% 105.9 4%
1024, 512 109.70 20% 109.2 -1%
2048, 512 107.40 24% 107.2 1%
4096, 512 108.00 6% 110.6 -3%
6144, 512 103.90 13% 105.8 7%
8192, 512 138.70 14% 105.6 7%
256, 1024 106.20 1% 102.9 6%
512, 1024 104.50 4% 104.2 3%
1024, 1024 126.90 -15% 103.9 10%
2048, 1024 127.40 -15% 102.2 6%
4096, 1024 117.70 6% 102.8 21%
6144, 1024 165.30 11% 112.2 12%
8192, 1024 211.90 11% 144.8 13%
256, 1536 102.80 11% 103.1 6%
512, 1536 103.30 9% 102.9 18%
1024, 1536 111.00 -2% 117.2 7%
2048, 1536 102.30 12% 132.1 -4%
4096, 1536 165.50 5% 112.9 18%
6144, 1536 236.60 5% 145.7 12%
8192, 1536 307.80 5% 186.1 11%
256, 2048 110.60 -1% 103.8 7%
512, 2048 105.20 3% 105.6 1%
1024, 2048 106.70 3% 114.8 3%
2048, 2048 124.90 5% 109.7 0%
4096, 2048 231.40 4% 129.9 10%
6144, 2048 332.80 4% 182.5 11%
8192, 2048 434.60 4% 235.2 11%
256, 3072 111.60 8% 110.8 1%
512, 3072 106.80 1% 104.6 10%
1024, 3072 104.90 3% 109.9 4%
2048, 3072 193.80 0% 106.2 10%
4096, 3072 364.50 0% 187.8 5%
6144, 3072 538.30 0% 267 5%
8192, 3072 718.00 -1% 346.7 6%
256, 4096 103.60 4% 110.2 -1%
512, 4096 131.40 -11% 117 -7%
1024, 4096 135.80 1% 104.8 7%
2048, 4096 268.20 1% 149.4 10%
4096, 4096 520.70 1% 268.5 9%
6144, 4096 786.30 0% 389.8 9%
8192, 4096 1043.50 0% 509 10%
```
Benchmarked with the following script from ngimel:
```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            # Fresh module and inputs for each (batch, features, dtype) combination.
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")  # progress indicator
c = Compare(results)
c.print()
```
Differential Revision: D40567574
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87445
Approved by: https://github.com/ngimel