torch._int_mm: fix triton kernel caching (#99283)
Summary:
A fix to ensure that kernels generated for `torch._int_mm` can be cached. We can remove this hack once eager mode `torch._int_mm` is better supported.
Let me know if a more proper fix is preferred over this hack.
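As background on why caching matters here: a compiled kernel is only reused when the cache key derived from the call (op, input dtypes, shapes, and so on) is stable across identical calls; if any piece of the key varies, the same triton kernel gets compiled again. The sketch below is purely illustrative and is not the actual inductor change in this PR; `get_kernel` and the dtype strings are hypothetical stand-ins.
```
from functools import lru_cache

compile_count = 0  # counts how many "compilations" actually happen

@lru_cache(maxsize=None)
def get_kernel(op: str, in_dtypes: tuple, out_dtype: str, shape: tuple):
    """Pretend to compile a kernel; identical keys reuse the cached result."""
    global compile_count
    compile_count += 1
    return f"{op}_{'x'.join(in_dtypes)}_to_{out_dtype}_{shape}"

# Two identical int8,int8->int32 matmul signatures should map to one compilation.
get_kernel("mm", ("int8", "int8"), "int32", (32, 32))
get_kernel("mm", ("int8", "int8"), "int32", (32, 32))
assert compile_count == 1  # compiled once, reused the second time
```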
Test plan:
```
# Before this PR, running the script below triggered two compilations of the
# triton int8,int8->int32 kernel; after this PR, there is a single compilation
# which is reused.
import torch
import torch.nn as nn
x = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8)
y = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8)
class M(nn.Module):
    def forward(self, x):
        x = torch._int_mm(x, y)
        x = x.to(torch.int8)
        x = torch._int_mm(x, y)
        return x
m = M().cuda().half()
m = torch.compile(m, options={"max-autotune": True})
z = m(x)
z = m(x)
```
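For reference, a minimal eager-mode check of the op whose compiled kernel is being cached above (assuming a CUDA device whose cuBLAS supports int8 matmul; the 32x32 shapes from the script satisfy the op's shape constraints):
```
import torch

# Eager-mode reference for the compiled kernel's signature: int8 x int8 -> int32.
a = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8)
b = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8)
out = torch._int_mm(a, b)
assert out.dtype == torch.int32
# The values fit exactly in float32, so compare against a float matmul of the same data.
torch.testing.assert_close(out.float(), a.float() @ b.float())
```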
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99283
Approved by: https://github.com/nmacchioni, https://github.com/janeyx99