Improve native layer norm backward perf (#68238)
Summary:
Benchmarks
With this PR:
```
[------------------------------------------------------ ln ------------------------------------------------------]
                  |  fwd, torch.float32  |  fwdbwd, torch.float32  |  fwd, torch.float16  |  fwdbwd, torch.float16
1 threads: -------------------------------------------------------------------------------------------------------
     200,   256   |          17.5        |           106.6         |          18.1        |             94.7
    1000,   256   |          18.7        |           116.6         |          18.7        |            110.7
    6000,   256   |          28.1        |           111.8         |          19.4        |             92.3
    6272,   256   |          29.3        |           108.5         |          20.1        |             92.7
     200,   512   |          19.3        |            83.8         |          19.1        |            116.3
    1000,   512   |          17.9        |            88.0         |          17.9        |             93.0
    6000,   512   |          36.9        |           141.2         |          27.4        |            103.3
    6272,   512   |          38.2        |           146.5         |          28.1        |            107.9
     200,  1024   |          18.1        |            89.5         |          21.1        |            102.7
    1000,  1024   |          17.9        |            88.7         |          18.5        |             92.5
    6000,  1024   |          77.6        |           277.5         |          40.3        |            148.5
    6272,  1024   |          80.7        |           288.1         |          42.0        |            154.0
     200,  1536   |          17.9        |           117.3         |          18.1        |             88.1
    1000,  1536   |          22.9        |            92.0         |          19.4        |             89.0
    6000,  1536   |         123.4        |           436.3         |          61.7        |            228.5
    6272,  1536   |         129.1        |           457.3         |          64.3        |            238.5
     200,  2048   |          18.0        |            90.5         |          19.1        |            101.6
    1000,  2048   |          31.1        |           109.8         |          25.3        |            107.9
    6000,  2048   |         174.5        |           589.8         |          87.1        |            310.5
    6272,  2048   |         182.2        |           617.0         |          91.2        |            316.7
     200,  3072   |          19.8        |            96.4         |          19.4        |             89.3
    1000,  3072   |          48.1        |           168.7         |          23.5        |            100.9
    6000,  3072   |         267.1        |           930.0         |         134.8        |            519.2
    6272,  3072   |         278.2        |           971.2         |         140.7        |            540.2
```
Before https://github.com/pytorch/pytorch/issues/67977:
```
[------------------------------------------------------- ln -------------------------------------------------------]
                  |  fwd, torch.float32  |  fwdbwd, torch.float32  |  fwd, torch.float16  |  fwdbwd, torch.float16
1 threads: ---------------------------------------------------------------------------------------------------------
     200,   256   |          20.9        |            92.6         |          21.3        |            110.1
    1000,   256   |          20.3        |            91.8         |          28.1        |            115.6
    6000,   256   |          93.0        |           310.7         |          86.3        |            299.8
    6272,   256   |          97.3        |           323.5         |          90.0        |            314.1
     200,   512   |          20.9        |           110.2         |          21.1        |             95.0
    1000,   512   |          24.0        |           102.8         |          22.2        |             95.9
    6000,   512   |         121.7        |           367.2         |         105.6        |            337.4
    6272,   512   |         127.0        |           382.3         |         111.3        |            352.0
     200,  1024   |          21.0        |           131.8         |          20.4        |             93.3
    1000,  1024   |          35.5        |           108.7         |          27.7        |             99.4
    6000,  1024   |         170.4        |           495.5         |         137.7        |            411.4
    6272,  1024   |         177.5        |           517.6         |         143.6        |            428.6
     200,  1536   |          21.9        |            97.6         |          20.8        |             92.7
    1000,  1536   |          44.3        |           129.7         |          33.9        |            100.1
    6000,  1536   |         215.8        |           619.2         |         167.2        |            480.9
    6272,  1536   |         225.0        |           646.9         |         174.8        |            505.9
     200,  2048   |          21.8        |           100.8         |          20.7        |             96.7
    1000,  2048   |          53.7        |           152.4         |          41.4        |            118.3
    6000,  2048   |         267.0        |           753.6         |         220.4        |            571.5
    6272,  2048   |         278.6        |           785.8         |         211.4        |            589.2
     200,  3072   |          20.9        |           103.7         |          21.9        |            104.6
    1000,  3072   |          71.4        |           201.1         |          53.1        |            148.3
    6000,  3072   |         365.7        |          1040.3         |         262.0        |            731.5
    6272,  3072   |         382.0        |          1084.4         |         273.3        |            766.3
```
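The improvement is largest for the big-batch rows. A throwaway snippet (not part of the PR) that turns a few representative fwdbwd, torch.float32 rows from the two tables above into speedup ratios:
```
# Throwaway helper, not part of the PR: speedups on a few representative
# (batch size, feature size) rows; values copied from the fwdbwd,
# torch.float32 columns of the two tables above.
before = {(6000, 256): 310.7, (6000, 1024): 495.5, (6000, 3072): 1040.3}
after = {(6000, 256): 111.8, (6000, 1024): 277.5, (6000, 3072): 930.0}
for shape, t_before in before.items():
    print(shape, f"{t_before / after[shape]:.2f}x")  # ~2.78x, ~1.79x, ~1.12x
```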
Benchmarking script:
```
import torch
from torch.utils.benchmark import Timer, Compare

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072):  # normalized feature size
        for bs in (200, 1000, 6000, 196*32):       # batch size (rows)
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(stmt=stmtfwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwd, {dtype}", globals=globals())
            tfwdbwd = Timer(stmt=stmtfwdbwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwdbwd, {dtype}", globals=globals())
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end='\r')  # progress indicator
c = Compare(results)
c.print()
```
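As a sanity check to go with the timings, the layer norm backward can be verified against numerical gradients. A minimal sketch (not part of the PR) using `torch.autograd.gradcheck`, which needs float64 inputs to keep the finite-difference comparison meaningful:
```
# Minimal correctness sketch, not part of the PR: gradcheck compares the
# analytic layer norm backward against numerical gradients in float64.
import torch

ln = torch.nn.LayerNorm((64,), device="cuda", dtype=torch.double)
X = torch.randn(8, 64, device="cuda", dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(ln, (X,))  # raises if the gradients disagree
```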
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68238
Reviewed By: mruberry
Differential Revision: D32469450
Pulled By: ngimel
fbshipit-source-id: 08fe755c156d3d5c366c966cb808bf0f3e74c050