Skip Triton templates in MM max autotune with zero-size inputs (#106865)
Summary:
MM max autotune (and friends) crash when one of the inputs is zero-size.
E.g., running this code:
```
@torch.compile()
def fn(x, y):
    return torch.mm(x, y)
inps = [torch.rand([0, 30]), torch.rand([30, 40])]
inps = [x.to(device="cuda") for x in inps]
out = fn(*inps)
```
with this command:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 python test.py
```
raises this error (the top of the stack trace is omitted for brevity):
```
...
File "/data/users/aakhundov/pytorch/torch/_inductor/kernel/mm.py", line 119, in tuned_mm
return autotune_select_algorithm("mm", choices, [mat1, mat2], layout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 960, in autotune_select_algorithm
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 787, in __call__
timings = self.lookup(
^^^^^^^^^^^^
File "/data/users/aakhundov/pytorch/torch/_inductor/codecache.py", line 267, in lookup
timings[choice] = benchmark(choice)
^^^^^^^^^^^^^^^^^
File "/data/users/aakhundov/pytorch/torch/_inductor/select_algorithm.py", line 774, in autotune
raise ErrorFromChoice(msg, choice, benchmark_fn.debug_str())
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: ErrorFromChoice: Please run `ptxas /tmp/compile-ptx-src-bfb1c6` to confirm that this is a bug in `ptxas`
From choice TritonTemplateCaller(/tmp/torchinductor_aakhundov/z7/cz7n7nn6rdlaelu4pbaaurgmu74ikl6g76lkngwawrevlfxlc6re.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4)
inputs = [
torch.empty_strided((0, 30), (30, 1), dtype=torch.float32, device='cuda'),
torch.empty_strided((30, 40), (40, 1), dtype=torch.float32, device='cuda'),
]
out = torch.empty_strided((0, 40), (40, 1), dtype=torch.float32, device='cuda')
target: aten.mm.default
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[0, s0], stride=[s0, 1]))
))
args[1]: TensorBox(StorageBox(
InputBuffer(name='arg3_1', layout=FixedLayout('cuda', torch.float32, size=[s0, s1], stride=[s1, 1]))
))
```
This PR adds a check that skips Triton templates in `mm`, `addmm`, and `mm_plus_mm` autotuning when the product of the MM problem shape (`m * n * k`) is zero.
Additionally, early-exit conditions on `M * N * K == 0` have been added to the `mm` and `mm_plus_mm` Triton templates. This prevents issues when autotuning is performed on non-zero-size inputs with dynamic shapes, and zero-size inputs are later encountered by the compiled model.
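The choice-filtering part of the fix can be illustrated with a small sketch. This is not the actual Inductor source; `mm_autotune_choices` and its parameters are hypothetical names used only to show the shape of the guard: when `m * n * k == 0` there is nothing to compute, so the Triton template choices are skipped and only the ATen fallback remains.

```python
def mm_autotune_choices(m, n, k, aten_choice, triton_choices):
    """Hypothetical helper illustrating the guard added by this PR.

    Triton template choices are only considered when the MM problem
    is non-empty; for a zero-size problem (m * n * k == 0), benchmarking
    the Triton kernels can crash, so only the ATen choice is kept.
    """
    choices = [aten_choice]
    if m * n * k != 0:
        # Non-empty problem: Triton templates are valid candidates.
        choices.extend(triton_choices)
    return choices
```

For the repro above (`m=0, k=30, n=40`), this would leave only the ATen fallback as an autotuning candidate.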
Test Plan:
```
$ python test/inductor/test_max_autotune.py -v
...
----------------------------------------------------------------------
Ran 16 tests in 29.569s
OK
```
Reviewers: @eellison
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106865
Approved by: https://github.com/jansel