[inductor] Adding a way to force fusion of int_mm with mul (#111125)
Summary: When doing quantization, int_mm -> mul or int_mm -> mul ->
to(dtype) is an extremely common op pattern that is currently not
handled well by inductor. Since the output of int_mm has dtype int32,
we would ideally prefer to realize only a smaller dtype like bf16 or
float16. Inductor currently has no way to force this; in many cases
the mul instead gets fused with a bunch of subsequent pointwise ops
from the dequant, increasing memory overhead and causing a general
slowdown compared to the fused version.
In theory, with better control of (or smarter) inductor fusion, this could be something we get for free, at which point these changes can be removed.
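For reference, a minimal sketch of the targeted pattern (shapes, device, and the per-channel scale tensor below are illustrative assumptions, not part of this PR):

```python
import torch

# int8 operands and a per-channel scale; torch._int_mm requires CUDA and
# shape constraints (e.g. inner/outer dims divisible by 8), so the sizes
# here are chosen to satisfy them.
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8, device="cuda")
scale = torch.rand(16, dtype=torch.bfloat16, device="cuda")

def int_mm_mul(a, b, scale):
    # torch._int_mm produces an int32 tensor; fusing the mul (and the
    # optional .to(dtype)) lets inductor realize a bf16 buffer instead
    # of the larger int32 one.
    return (torch._int_mm(a, b) * scale).to(torch.bfloat16)

compiled = torch.compile(int_mm_mul)
out = compiled(a, b, scale)
```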
Test Plan: python test/inductor/test_pattern_matcher.py -k "int_mm_mul"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111125
Approved by: https://github.com/jansel, https://github.com/cpuhrsch