adding mixed_dtype_mm to torch._inductor (#106443)
Summary: If torch._inductor.config.use_mixed_mm is set, we can convert
torch.mm(a, b.to(some_dtype)) into a Triton kernel where the cast of b
is fused into the matmul, rather than needing to materialize the cast b
tensor. When use_mixed_mm is set, this fused kernel is autotuned
against the default two-kernel fallback option. If force_mixed_mm is set,
the fused kernel is always used. This option is needed for weight-only
quantization, where in some cases we rely on the superior memory
characteristics of the fused kernel rather than on the perf numbers
(i.e. when we can't afford to load memory with a tensor 4x the size of
our quantized one).
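
For reference, a minimal sketch of how the pattern might be exercised; the shapes, dtypes, and device below are illustrative assumptions, not taken from this PR:

    import torch
    import torch._inductor.config as inductor_config

    inductor_config.use_mixed_mm = True     # autotune fused cast+mm against the fallback
    # inductor_config.force_mixed_mm = True # always use the fused kernel

    def mixed_mm(a, b):
        # b stays in a small dtype (e.g. int8 weights); under use_mixed_mm
        # the .to() cast is fused into the matmul instead of materializing
        # a full-size fp16 copy of b.
        return torch.mm(a, b.to(a.dtype))

    a = torch.randn(8, 8, dtype=torch.float16, device="cuda")
    b = torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda")
    out = torch.compile(mixed_mm)(a, b)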
Test Plan: python test/inductor/test_pattern_matcher.py -k "mixed_mm"
python test/inductor/test_torchinductor.py -k "mixed_mm"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106443
Approved by: https://github.com/jansel