[cuda] vectorized gamma and beta loading in vectorized_layer_norm (#107287)
Improves the performance of `vectorized_layer_norm` by vectorizing access to the `gamma` and `beta` buffers. This uses 128-bit load instructions, which improves memory bandwidth. The speedup is ~3% on average and there are no obvious regressions for any problem size.
Used the following code to test:
```python
import torch
from torch.utils.benchmark import Compare, Timer # @manual
# Problem sizes to sweep; each tuple is (batch_size, feature_size).
l_inputs = [
    (32, 32),
    (64, 32),
    (256, 128),
    (512, 1024),
    (1024, 2048),
    (2048, 2048),
    (4096, 16384),
    (70000, 64),
    (131072, 512),
    (1000, 520),
    (4005, 4005),
    (10000, 1000),
    (1024, 10000),
    (8192, 4096),
    (10000, 10000),
    (3072, 10000),
    (6144, 10000),
    (1024, 20000),
    # NOTE(review): duplicate of the previous entry — presumably an
    # intentional repeat of that size; confirm before removing.
    (1024, 20000),
    (512, 1536),
    (512, 6144),
    (512, 10240),
    (1000, 1000),
    (2000, 2000),
    (10240, 10240),
    (384, 128),
    (2048, 1024),
    (267, 513),
    (67, 123479),
    (1024, 123479),
    (2048, 66679),
    (200, 256),
    (1000, 256),
    (6000, 256),
    (6272, 256),
    (200, 512),
    (1000, 512),
    (6000, 512),
    (6272, 512),
    (200, 1024),
    (1000, 1024),
    (6000, 1024),
    (6272, 1024),
    (200, 2048),
    (1000, 2048),
    (6000, 2048),
    (6272, 2048),
    (200, 3072),
    (1000, 3072),
    (6000, 3072),
    (6272, 3072),
]
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    """Run one LayerNorm forward + backward pass and return the parameter grads.

    Args:
        fs: feature size (normalized dimension of the LayerNorm).
        X: input tensor with ``requires_grad=True``; its ``.grad`` is cleared
           before the pass so results are independent of prior calls.
        gO: upstream gradient tensor, same shape as ``X``.
        device_string: device to build the LayerNorm on, e.g. "cpu" or "cuda".
        numeric_type: dtype for the LayerNorm parameters.

    Returns:
        Tuple ``(weight_grad, bias_grad)`` of the LayerNorm parameters after
        the backward pass.
    """
    # Fix: the script's indentation was flattened in the commit message;
    # structure restored here so the code is runnable as written.
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)
def run_correctness_test(eps_weight, eps_bias):
    """Compare CPU vs CUDA LayerNorm weight/bias gradients over ``l_inputs``.

    For each (batch_size, feature_size) problem, runs the same forward +
    backward pass on CPU and CUDA and prints the percentage of gradient
    elements whose absolute difference meets or exceeds the given thresholds.

    Args:
        eps_weight: mismatch threshold for weight-gradient elements.
        eps_bias: mismatch threshold for bias-gradient elements.
    """
    # Fix: indentation restored (flattened in the commit message).
    dtype = torch.float
    for val in l_inputs:
        bs = val[0]
        fs = val[1]
        # Per-feature scaling so inputs do not all share the same variance.
        mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
        X = mean_adjustment * torch.randn(
            bs, fs, device="cpu", dtype=torch.float, requires_grad=True
        )
        # Detach so X is a leaf tensor that accumulates .grad itself.
        X = X.detach().requires_grad_()
        gO = torch.rand_like(X)
        X_gpu = X.to("cuda")
        X_gpu = X_gpu.detach().requires_grad_()
        gO_gpu = gO.to("cuda")
        gO_gpu = gO_gpu.detach().requires_grad_()
        grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
        grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
        weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
        bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")
        weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
        weight_mismatches = (weight_delta >= eps_weight).nonzero()
        weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100
        bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
        bias_mismatches = (bias_delta >= eps_bias).nonzero()
        bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100
        # NOTE(review): prints fs first although tensors are (bs, fs);
        # order kept as-is — confirm which ordering was intended.
        print(
            "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                fs, bs, weight_mismatch_pct, bias_mismatch_pct
            )
        )
# Run the correctness tests
run_correctness_test(0.01, 0.01)

# Run the performance tests. We need to run this at global scope because otherwise
# the `ln` and `gO` objects are likely removed by the JIT compiler
# (indentation restored — it was flattened in the commit message).
results = []
for dtype in (torch.float, torch.half):
    for val in l_inputs:
        bs = val[0]
        fs = val[1]
        ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
        X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
        gO = torch.rand_like(X)
        # Timed statements: forward only, and full forward+backward.
        stmtfwd = "ln(X)"
        stmtfwdbwd = (
            "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
        )
        tfwd = Timer(
            stmt=stmtfwd,
            label="ln",
            sub_label=f"{bs:5}, {fs:5}",
            description=f"fwd, {dtype}",
            globals=globals(),
        )
        tfwdbwd = Timer(
            stmt=stmtfwdbwd,
            label="ln",
            sub_label=f"{bs:5}, {fs:5}",
            description=f"fwdbwd, {dtype}",
            globals=globals(),
        )
        for t in (tfwd, tfwdbwd):
            results.append(t.blocked_autorange())
        # Progress indicator; \r keeps it on one line.
        print(fs, end="\r")
c = Compare(results)
c.print()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107287
Approved by: https://github.com/malfet