[inductor][Gemm] Autotune with matrix_instr_nonkdim for AMDGPU (#120742)
Relanding https://github.com/pytorch/pytorch/pull/120639 with a fix that drops `matrix_instr_nonkdim` values that do not align with `BLOCK_M` or `BLOCK_N`.
Matrix multiplication with Triton is usually done in a tiled fashion. When a tile is too large for a single hardware instruction, e.g., a 128x128 matmul tile, it has to be broken down into a sequence of smaller hardware MMA instructions. On AMDGPU, `matrix_instr_nonkdim` controls the shape of those MMA instructions; its default value in Triton is 0, which means Triton decomposes a large tiled matmul into a sequence of 32x32x8 MMA instructions. Other MMA instructions are available, such as 16x16x16, which requires `matrix_instr_nonkdim=16`. This change adds `matrix_instr_nonkdim` to the Gemm autotuning space, which appears to improve performance by 20% to 2x in the benchmarks below.
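Conceptually, the expanded search space is the cross product of the existing block configs with candidate `matrix_instr_nonkdim` values, and the relanded fix prunes combinations where the MMA tile does not align with the thread-block tile. The sketch below is only a rough, hypothetical illustration of that pruning (names and structure are illustrative, not the actual inductor code):

```python
# Hypothetical sketch of how the extra tuning dimension could be generated and
# pruned; illustrative only, not the actual inductor implementation.
from itertools import product

# Existing (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) candidates, similar to the logs below.
base_configs = [
    (64, 64, 32, 4),
    (64, 128, 32, 8),
    (128, 128, 32, 8),
    (32, 32, 16, 2),
]

# Candidate MMA instruction sizes on AMDGPU (0 keeps Triton's default 32x32x8).
matrix_instr_nonkdim_choices = [0, 16, 32]

def aligned(block_m, block_n, nonkdim):
    # The relanded fix: drop values that do not align with BLOCK_M or BLOCK_N,
    # since the MMA tile must evenly divide the thread-block tile.
    if nonkdim == 0:
        return True
    return block_m % nonkdim == 0 and block_n % nonkdim == 0

configs = [
    dict(BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, num_warps=w, matrix_instr_nonkdim=d)
    for (m, n, k, w), d in product(base_configs, matrix_instr_nonkdim_choices)
    if aligned(m, n, d)
]

for c in configs:
    print(c)  # each surviving config becomes an autotune candidate
```

Each surviving combination then shows up as an autotune candidate like the entries in the "After" logs below.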
Before:
```
AUTOTUNE mm(1024x1024, 1024x1024)
ExternKernelCaller(extern_kernels.mm) 0.0410 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.0487 ms 84.2%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.0544 ms 75.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.0633 ms 64.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.0687 ms 59.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.0716 ms 57.3%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.0748 ms 54.9%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.0788 ms 52.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=2 0.1014 ms 40.5%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.1069 ms 38.4%
SingleProcess AUTOTUNE takes 8.1153 seconds
```
After:
```
AUTOTUNE mm(1024x1024, 1024x1024)
ExternKernelCaller(extern_kernels.mm) 0.0417 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.0470 ms 88.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 0.0488 ms 85.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=4 0.0490 ms 85.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 0.0525 ms 79.5%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=4 0.0553 ms 75.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.0574 ms 72.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=8 0.0634 ms 65.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=2 0.0655 ms 63.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.0681 ms 61.2%
SingleProcess AUTOTUNE takes 11.4076 seconds
```
Before:
```
AUTOTUNE mm(2048x2048, 2048x2048)
ExternKernelCaller(extern_kernels.mm) 0.2094 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.2452 ms 85.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.2763 ms 75.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.2836 ms 73.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.2854 ms 73.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.2951 ms 71.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.2970 ms 70.5%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 0.4184 ms 50.1%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 0.5097 ms 41.1%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=2 0.5570 ms 37.6%
SingleProcess AUTOTUNE takes 3.4052 seconds
```
After:
```
AUTOTUNE mm(2048x2048, 2048x2048)
ExternKernelCaller(extern_kernels.mm) 0.2117 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=8 0.2429 ms 87.2%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.2485 ms 85.2%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.2526 ms 83.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 0.2537 ms 83.5%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 0.2554 ms 82.9%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 0.2623 ms 80.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 0.2695 ms 78.5%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=8 0.2758 ms 76.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=4 0.2792 ms 75.8%
SingleProcess AUTOTUNE takes 11.3538 seconds
```
Before:
```
AUTOTUNE mm(4096x4096, 4096x4096)
ExternKernelCaller(extern_kernels.mm) 1.5901 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 1.9380 ms 82.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 1.9943 ms 79.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 2.0640 ms 77.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 2.0941 ms 75.9%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 2.1272 ms 74.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 2.1554 ms 73.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 2.2931 ms 69.3%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 3.7016 ms 43.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=2 4.6021 ms 34.6%
SingleProcess AUTOTUNE takes 9.0523 seconds
```
After:
```
AUTOTUNE mm(4096x4096, 4096x4096)
ExternKernelCaller(extern_kernels.mm) 1.5862 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 1.6924 ms 93.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 1.7616 ms 90.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 1.8159 ms 87.4%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 1.9340 ms 82.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 1.9352 ms 82.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 2.0378 ms 77.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 2.0983 ms 75.6%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 2.1138 ms 75.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 2.1657 ms 73.2%
SingleProcess AUTOTUNE takes 8.2225 seconds
```
Before:
```
AUTOTUNE mm(8192x8192, 8192x8192)
ExternKernelCaller(extern_kernels.mm) 12.0134 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 14.8082 ms 81.1%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 15.4242 ms 77.9%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 16.6869 ms 72.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 16.7751 ms 71.6%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 17.0145 ms 70.6%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 17.1363 ms 70.1%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=4 18.2159 ms 66.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=8 29.4726 ms 40.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=0, num_warps=2 37.9039 ms 31.7%
SingleProcess AUTOTUNE takes 11.0074 seconds
```
After:
```
AUTOTUNE mm(8192x8192, 8192x8192)
ExternKernelCaller(extern_kernels.mm) 11.9554 ms 100.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 12.9953 ms 92.0%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 13.7726 ms 86.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 13.9647 ms 85.6%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=8 14.9728 ms 79.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=8 15.3729 ms 77.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 15.3955 ms 77.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=32, num_stages=0, num_warps=4 15.5647 ms 76.8%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 16.0037 ms 74.7%
ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, matrix_instr_nonkdim=16, num_stages=0, num_warps=4 16.7432 ms 71.4%
SingleProcess AUTOTUNE takes 14.9839 seconds
```
Reviewed By: xw285cornell, nmacchioni
Differential Revision: D54203170
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120742
Approved by: https://github.com/xw285cornell