Implement bf16xint16 kernel (#2349)
Summary:
Pull Request resolved: https://github.com/pytorch/benchmark/pull/2349
Modify the bf16xint16 kernel from the previous PR so it actually does what was intended: load the int16 input, convert it to bf16 inside the kernel, and then do the matmul.
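As a rough sketch of the idea (not the exact kernel in this PR), the inner loop loads one operand as int16 and casts it to bf16 in-kernel before `tl.dot`. Which operand is int16, the pointer/stride names, and the no-masking assumption (M, N, K divisible by the block sizes) are illustrative assumptions here:
```python
import triton
import triton.language as tl


@triton.jit
def bf16xint16_gemm_kernel(
    a_ptr, b_ptr, c_ptr,            # a: bf16 (M, K), b: int16 (K, N), c: bf16 (M, N)
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)          # bf16 tile
        b = tl.load(b_ptrs)          # int16 tile
        b = b.to(tl.bfloat16)        # conversion happens inside the kernel
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.bfloat16))
```
Launched over a `(cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))` grid, this contrasts with the bf16xint16_casted variant in the table below, which (judging from the name) casts the int16 tensor to bf16 up front and then runs a plain bf16xbf16 matmul.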
On H100:
```
$ python run_benchmark.py triton -- --op bf16xint16_gemm
x_val bf16xbf16-best_config bf16xbf16-gbps bf16xbf16-latency bf16xbf16-tflops bf16xint16-best_config bf16xint16-gbps bf16xint16-latency bf16xint16-tflops bf16xint16_casted-best_config bf16xint16_casted-gbps bf16xint16_casted-latency bf16xint16_casted-tflops
------------------- -------------------------------------------------------------------------------------------------------------------------------- ---------------- ------------------- ------------------ --------------------------------------------------------------------------------------------------------------------------------- ----------------- -------------------- ------------------- ------------------------------- ------------------------ --------------------------- --------------------------
...
(16384, 1280, 8192) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 557.04 0.561899 611.494 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 225.308 1.38921 247.333 524.85 0.596361 576.157
(16384, 8192, 1024) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 527.772 0.576171 477.077 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 263.431 1.15433 238.127 498.965 0.609436 451.037
(16384, 7168, 8192) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 165.303 3.13361 614.034 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 65.4821 7.91051 243.239 156.515 3.30957 581.389
(16384, 8192, 3584) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 239.774 1.63995 586.649 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 101.813 3.86212 249.105 225.613 1.74288 552.003
(65536, 1280, 8192) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 562.389 2.21223 621.268 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 222.597 5.58917 245.902 552.792 2.25064 610.666
(65536, 8192, 1024) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 539.745 2.2419 490.437 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 270.808 4.46832 246.068 532.576 2.27208 483.922
(65536, 7168, 8192) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 166.428 12.1851 631.639 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 66.2297 30.6199 251.359 163.514 12.4023 620.577
(65536, 8192, 3584) BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 243.299 6.37422 603.727 BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 128, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 107.069 14.4846 265.682 239.25 6.48212 593.678
```
On A100, this crashes due to some bugs/unhandled cases in Triton.
imported-using-ghimport
Test Plan: Imported from OSS
Reviewed By: xuzhao9
Differential Revision: D59234866
Pulled By: davidberard98
fbshipit-source-id: 46f0d671ce7bf9315d7ea7551663b03a36da3bc3