[ROCm] Replace layer_norm_grad_input_kernel with cuComputeGradInput for ROCm (#87726)
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when fs (=config_m in our benchmark script) is large and bs (=config_n in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of https://github.com/pytorch/pytorch/pull/68238#issue-1051621716 on AMD GPUs.
This PR is to replace layer_norm_grad_input_kernel with the Apex cuComputeGradInput kernel with some ROCm-specific parameter tuning when fs (=config_m) is larger than or equal to `32768` on AMD GPUs. Some of the code changes in LayerNormBackwardKernelImplInternal are from another PR: https://github.com/pytorch/pytorch/pull/87635
We used the same benchmark script in the previous PR and tested the optimized kernel with various input shapes on AMD MI100 GPU.
**At [the previous PR](https://github.com/pytorch/pytorch/pull/87635):**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
{mso-displayed-decimal-separator:"\.";
mso-displayed-thousand-separator:"\,";}
@page
{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
margin:.75in .7in .75in .7in;
mso-header-margin:.3in;
mso-footer-margin:.3in;}
tr
{mso-height-source:auto;}
col
{mso-width-source:auto;}
br
{mso-data-placement:same-cell;}
td
{padding-top:1px;
padding-right:1px;
padding-left:1px;
mso-ignore:padding;
color:black;
font-size:11.0pt;
font-weight:400;
font-style:normal;
text-decoration:none;
font-family:Calibri, sans-serif;
mso-font-charset:0;
mso-number-format:General;
text-align:general;
vertical-align:bottom;
border:none;
mso-background-source:auto;
mso-pattern:auto;
mso-protection:locked visible;
white-space:nowrap;
mso-rotate:0;}
.xl65
{color:windowtext;}
-->
</head>
<body link="#0563C1" vlink="#954F72">
M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38589 | 0.92603 | 0.38367 | 1.15148
50176 | 384 | 0.38719 | 0.91579 | 0.37815 | 1.13761
200704 | 192 | 0.99787 | 2.39954 | 0.98996 | 2.54284
802816 | 64 | 3.66525 | 7.96952 | 3.61293 | 7.69946
200 | 256 | 0.06578 | 0.34613 | 0.06966 | 0.35449
1000 | 256 | 0.07837 | 0.37631 | 0.07725 | 0.37758
6000 | 256 | 0.09318 | 0.3788 | 0.09202 | 0.37989
6272 | 256 | 0.08694 | 0.36267 | 0.08703 | 0.3615
200 | 512 | 0.06975 | 0.34506 | 0.06973 | 0.34208
1000 | 512 | 0.07012 | 0.36363 | 0.07307 | 0.36741
6000 | 512 | 0.09725 | 0.36251 | 0.09908 | 0.37078
6272 | 512 | 0.09899 | 0.36519 | 0.10068 | 0.37514
200 | 1024 | 0.07188 | 0.33896 | 0.0712 | 0.34683
1000 | 1024 | 0.07357 | 0.3625 | 0.0734 | 0.3598
6000 | 1024 | 0.12642 | 0.38949 | 0.12973 | 0.5035
6272 | 1024 | 0.12901 | 0.40759 | 0.13609 | 0.51871
200 | 1536 | 0.06998 | 0.34782 | 0.07419 | 0.3514
1000 | 1536 | 0.07987 | 0.37915 | 0.07888 | 0.37264
6000 | 1536 | 0.15401 | 0.47524 | 0.15416 | 0.68609
6272 | 1536 | 0.15286 | 0.48843 | 0.17681 | 0.72997
200 | 2048 | 0.07054 | 0.34791 | 0.07289 | 0.35138
1000 | 2048 | 0.07767 | 0.37954 | 0.08554 | 0.37464
6000 | 2048 | 0.18744 | 0.5811 | 0.25004 | 0.93338
6272 | 2048 | 0.20037 | 0.63398 | 0.26918 | 0.97018
200 | 3072 | 0.07687 | 0.36739 | 0.08917 | 0.37845
1000 | 3072 | 0.09323 | 0.38901 | 0.09739 | 0.39823
6000 | 3072 | 0.24314 | 0.89029 | 0.38093 | 1.30719
6272 | 3072 | 0.26079 | 0.92023 | 0.38352 | 1.51012
128 | 2097152 | 6.17775 | 23.876 | 10.27952 | 30.10848
256 | 1048576 | 4.51855 | 19.47637 | 10.07609 | 29.42678
512 | 524288 | 4.13615 | 18.80888 | 10.07853 | 32.29804
1024 | 262144 | 4.47397 | 17.88388 | 9.50367 | 31.15699
2048 | 131072 | 4.2458 | 16.70852 | 9.17979 | 30.51708
4096 | 65536 | 4.24412 | 16.43098 | 8.97651 | 30.1617
8192 | 32768 | 4.24556 | 16.09038 | 8.77001 | 30.3643
16384 | 16384 | 4.14642 | 15.80355 | 8.82402 | 30.35291
32768 | 8192 | 4.12599 | 15.68897 | 8.82605 | 30.43423
</body>
</html>
----
**At this PR:**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
{mso-displayed-decimal-separator:"\.";
mso-displayed-thousand-separator:"\,";}
@page
{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
margin:.75in .7in .75in .7in;
mso-header-margin:.3in;
mso-footer-margin:.3in;}
tr
{mso-height-source:auto;}
col
{mso-width-source:auto;}
br
{mso-data-placement:same-cell;}
td
{padding-top:1px;
padding-right:1px;
padding-left:1px;
mso-ignore:padding;
color:black;
font-size:11.0pt;
font-weight:400;
font-style:normal;
text-decoration:none;
font-family:Calibri, sans-serif;
mso-font-charset:0;
mso-number-format:General;
text-align:general;
vertical-align:bottom;
border:none;
mso-background-source:auto;
mso-pattern:auto;
mso-protection:locked visible;
white-space:nowrap;
mso-rotate:0;}
.xl65
{color:windowtext;}
.xl66
{background:yellow;
mso-pattern:black none;}
-->
</head>
<body link="#0563C1" vlink="#954F72">
M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38667 | 0.84133 | 0.37916 | 1.01222
50176 | 384 | 0.3814 | 0.87266 | 0.37858 | 1.04399
200704 | 192 | 0.99902 | 2.14386 | 0.98973 | 2.33265
802816 | 64 | 3.66578 | 6.85376 | 3.6092 | 7.00331
200 | 256 | 0.06607 | 0.34176 | 0.07009 | 0.34548
1000 | 256 | 0.06947 | 0.36461 | 0.07902 | 0.37851
6000 | 256 | 0.09319 | 0.37432 | 0.09342 | 0.36927
6272 | 256 | 0.09544 | 0.37565 | 0.09476 | 0.37377
200 | 512 | 0.07935 | 0.364 | 0.07891 | 0.36894
1000 | 512 | 0.07676 | 0.37552 | 0.07957 | 0.37564
6000 | 512 | 0.10472 | 0.37504 | 0.1051 | 0.38782
6272 | 512 | 0.1069 | 0.36662 | 0.10062 | 0.38506
200 | 1024 | 0.07793 | 0.36561 | 0.08023 | 0.35019
1000 | 1024 | 0.07426 | 0.36729 | 0.07345 | 0.35851
6000 | 1024 | 0.12729 | 0.39219 | 0.12974 | 0.51526
6272 | 1024 | 0.13622 | 0.41627 | 0.14252 | 0.52926
200 | 1536 | 0.07615 | 0.36621 | 0.0797 | 0.3695
1000 | 1536 | 0.08327 | 0.38174 | 0.07938 | 0.37573
6000 | 1536 | 0.14894 | 0.46197 | 0.15268 | 0.63814
6272 | 1536 | 0.15368 | 0.48818 | 0.16309 | 0.71441
200 | 2048 | 0.06935 | 0.36691 | 0.07258 | 0.35548
1000 | 2048 | 0.07738 | 0.36388 | 0.08036 | 0.36452
6000 | 2048 | 0.18757 | 0.58573 | 0.23701 | 0.92915
6272 | 2048 | 0.1938 | 0.61628 | 0.26475 | 0.96896
200 | 3072 | 0.07884 | 0.3673 | 0.07724 | 0.37869
1000 | 3072 | 0.09342 | 0.38193 | 0.09822 | 0.38646
6000 | 3072 | 0.24452 | 0.86776 | 0.38251 | 1.3036
6272 | 3072 | 0.25971 | 0.91053 | 0.38744 | 1.39039
128 | 2097152 | 6.06752 | 23.26379 | 9.87466 | 29.81851
256 | 1048576 | 4.50336 | 19.4614 | 10.11239 | 29.25554
512 | 524288 | 4.12649 | 18.72831 | 10.054 | 32.26784
1024 | 262144 | 4.40855 | 17.77993 | 9.38856 | 31.18679
2048 | 131072 | 4.18716 | 16.74615 | 9.14487 | 30.24603
4096 | 65536 | 4.17374 | 16.34444 | 8.94894 | 30.0326
8192 | 32768 | 4.19095 | 16.05751 | 8.70358 | 30.14669
16384 | 16384 | 4.15404 | 15.83771 | 8.80042 | 30.5022
32768 | 8192 | 4.12515 | 15.5657 | 8.66138 | 28.87386
</body>
</html>
---
**Performance Improvement (%)**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
{mso-displayed-decimal-separator:"\.";
mso-displayed-thousand-separator:"\,";}
@page
{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
margin:.75in .7in .75in .7in;
mso-header-margin:.3in;
mso-footer-margin:.3in;}
tr
{mso-height-source:auto;}
col
{mso-width-source:auto;}
br
{mso-data-placement:same-cell;}
td
{padding-top:1px;
padding-right:1px;
padding-left:1px;
mso-ignore:padding;
color:black;
font-size:11.0pt;
font-weight:400;
font-style:normal;
text-decoration:none;
font-family:Calibri, sans-serif;
mso-font-charset:0;
mso-number-format:General;
text-align:general;
vertical-align:bottom;
border:none;
mso-background-source:auto;
mso-pattern:auto;
mso-protection:locked visible;
white-space:nowrap;
mso-rotate:0;}
.xl65
{color:windowtext;}
.xl66
{mso-number-format:"0\.000";}
-->
</head>
<body link="#0563C1" vlink="#954F72">
M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32
-- | -- | -- | --
50432 | 384 | 9.147 | 12.094
50176 | 384 | 4.710 | 8.230
200704 | 192 | 10.655 | 8.266
802816 | 64 | 14.000 | 9.042
200 | 256 | 1.263 | 2.542
1000 | 256 | 3.109 | -0.246
6000 | 256 | 1.183 | 2.796
6272 | 256 | -3.579 | -3.394
200 | 512 | -5.489 | -7.852
1000 | 512 | -3.270 | -2.240
6000 | 512 | -3.456 | -4.596
6272 | 512 | -0.392 | -2.644
200 | 1024 | -7.862 | -0.969
1000 | 1024 | -1.321 | 0.359
6000 | 1024 | -0.693 | -2.336
6272 | 1024 | -2.130 | -2.034
200 | 1536 | -5.287 | -5.151
1000 | 1536 | -0.683 | -0.829
6000 | 1536 | 2.792 | 6.989
6272 | 1536 | 0.051 | 2.132
200 | 2048 | -5.461 | -1.167
1000 | 2048 | 4.126 | 2.701
6000 | 2048 | -0.797 | 0.453
6272 | 2048 | 2.792 | 0.126
200 | 3072 | 0.024 | -0.063
1000 | 3072 | 1.820 | 2.956
6000 | 3072 | 2.531 | 0.275
6272 | 3072 | 1.054 | 7.929
128 | 2097152 | 2.564 | 0.963
256 | 1048576 | 0.077 | 0.582
512 | 524288 | 0.428 | 0.094
1024 | 262144 | 0.581 | -0.096
2048 | 131072 | -0.225 | 0.888
4096 | 65536 | 0.527 | 0.428
8192 | 32768 | 0.204 | 0.717
16384 | 16384 | -0.216 | -0.492
32768 | 8192 | 0.786 | 5.127
</body>
</html>
CC: @jeffdaily
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87726
Approved by: https://github.com/ngimel