pytorch
5ff2ad6f - torch._int_mm: fix triton kernel caching (#99283)

Commit
1 year ago
torch._int_mm: fix triton kernel caching (#99283) Summary: A fix to ensure that kernels generated for `torch._int_mm` can be cached. We can remove this hack one eager mode `torch._int_mm` is better supported. Let me know if something more proper is needed instead of the hack. Test plan: ``` // running the script below led to two compilations of triton // int8,int8->int32 kernel before this PR, and only has // one compilation which is reused after this PR import torch import torch.nn as nn x = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8) y = torch.randint(-128, 127, (32, 32), device='cuda', dtype=torch.int8) class M(nn.Module): def forward(self, x): x = torch._int_mm(x, y) x = x.to(torch.int8) x = torch._int_mm(x, y) return x m = M().cuda().half() m = torch.compile(m, options={"max-autotune": True}) z = m(x) z = m(x) ``` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/99283 Approved by: https://github.com/nmacchioni, https://github.com/janeyx99
Author
Committer
Parents
Loading