pytorch
f4297576 - [inductor] Adding a way to force fusion of int_mm with mul (#111125)

Summary: When doing quantization, int_mm -> mul or int_mm -> mul -> to(dtype) is an extremely common op pattern that inductor currently does not handle well. Since the output of int_mm has dtype int32, we would ideally realize only a smaller dtype like bf16 or float16. Inductor currently has no way to force this: in many cases the mul instead gets fused with a chain of subsequent pointwise ops from the dequant, increasing memory overhead and causing a general slowdown compared to the fused version. In theory, with better control of (or smarter) inductor fusion this could come for free, at which point these changes can be removed.

Test Plan:
python test/inductor/test_pattern_matcher.py -k "int_mm_mul"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111125
Approved by: https://github.com/jansel, https://github.com/cpuhrsch
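For context, a minimal sketch of the op pattern the commit targets. A plain int32 matmul stands in for torch._int_mm (whose backend support varies), and the shapes, names, and scale tensor are illustrative assumptions, not taken from the PR:

```python
import torch

# Illustrative sketch of the int_mm -> mul -> to(dtype) pattern described
# above. The int32 matmul below is a stand-in for torch._int_mm; shapes
# and names are made up for the example.
a = torch.randint(-128, 128, (16, 32), dtype=torch.int8)
b = torch.randint(-128, 128, (32, 8), dtype=torch.int8)
scale = torch.rand(16, 8)

acc = a.to(torch.int32) @ b.to(torch.int32)   # int_mm: int32 accumulator
out = (acc * scale).to(torch.bfloat16)        # mul -> to(dtype)

# Fusing the mul (and the downcast) into the matmul epilogue lets inductor
# avoid materializing the full int32 accumulator in memory.
print(out.dtype)  # torch.bfloat16
```

Left unfused, the int32 accumulator is written out and read back before the mul, which is the memory overhead the commit describes.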