[ROCm] Optimize layer norm backward kernel for ROCm (#87635)
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR #68238](https://github.com/pytorch/pytorch/pull/68238#issue-1051621716) on AMD GPUs.
This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs` (=`config_m`) is larger than 512 on AMD GPUs.
There are a few PRs for LayerNorm kernel:
- https://github.com/pytorch/pytorch/pull/26201
- https://github.com/pytorch/pytorch/pull/27634
- https://github.com/pytorch/pytorch/pull/68238
Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.
---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
{mso-displayed-decimal-separator:"\.";
mso-displayed-thousand-separator:"\,";}
@page
{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
margin:.75in .7in .75in .7in;
mso-header-margin:.3in;
mso-footer-margin:.3in;}
tr
{mso-height-source:auto;}
col
{mso-width-source:auto;}
br
{mso-data-placement:same-cell;}
td
{padding-top:1px;
padding-right:1px;
padding-left:1px;
mso-ignore:padding;
color:black;
font-size:11.0pt;
font-weight:400;
font-style:normal;
text-decoration:none;
font-family:Calibri, sans-serif;
mso-font-charset:0;
mso-number-format:General;
text-align:general;
vertical-align:bottom;
border:none;
mso-background-source:auto;
mso-pattern:auto;
mso-protection:locked visible;
white-space:nowrap;
mso-rotate:0;}
-->
</head>
<body link="#0563C1" vlink="#954F72">
M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271
</body>
</html>
---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
{mso-displayed-decimal-separator:"\.";
mso-displayed-thousand-separator:"\,";}
@page
{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
margin:.75in .7in .75in .7in;
mso-header-margin:.3in;
mso-footer-margin:.3in;}
tr
{mso-height-source:auto;}
col
{mso-width-source:auto;}
br
{mso-data-placement:same-cell;}
td
{padding-top:1px;
padding-right:1px;
padding-left:1px;
mso-ignore:padding;
color:black;
font-size:11.0pt;
font-weight:400;
font-style:normal;
text-decoration:none;
font-family:Calibri, sans-serif;
mso-font-charset:0;
mso-number-format:General;
text-align:general;
vertical-align:bottom;
border:none;
mso-background-source:auto;
mso-pattern:auto;
mso-protection:locked visible;
white-space:nowrap;
mso-rotate:0;}
.xl63
{color:windowtext;}
-->
</head>
<body link="#0563C1" vlink="#954F72">
M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514
</body>
</html>
---------
**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>
<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->
<div style='direction:ltr'>
M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491
</div>
<!--EndFragment-->
</body>
</html>
---------
**Benchmark script of this PR**
```
# Ref:
# 1. https://github.com/pytorch/pytorch/pull/26201
# 2. https://github.com/pytorch/pytorch/pull/68238
from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit
number_runs = 1000 # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
layer_norm_cuda(input_cuda); torch.cuda.synchronize()
def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()
def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
input_cuda.grad = None
layer_norm_cuda.zero_grad(set_to_none=True)
out = layer_norm_cuda(input_cuda)
out.backward(gO)
torch.cuda.synchronize()
def benchmark(config_m, config_n):
print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
if len(config_m) != len(config_n):
print("Please make sure the lengths of config_m and config_m are the same.")
for i in range(len(config_m)):
normalized_shape = config_n[i]
results = [config_m[i], config_n[i]]
for dtype in (torch.half, torch.float):
if dtype == torch.half:
layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
else:
layer_norm_cuda = LayerNorm(normalized_shape).cuda()
input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)
# print("cuda forward:")
result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
results.append(result_fwd / number_runs * 1000)
gO = torch.rand_like(input_cuda)
result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
results.append(result_fwdbwd / number_runs * 1000)
print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))
print("Times are in microseconds (us).")
# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]
# https://github.com/pytorch/pytorch/pull/68238#issue-1051621716
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]
# https://github.com/pytorch/pytorch/pull/27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]
config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634
benchmark(config_m, config_n)
```
CC: @jeffdaily
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang