adding fused uint4x2_mixed_mm to inductor (#106516)
Summary: this is needed for int4 weight-only quantization, we're
matching on the specific unpack operation that unpacks the uint4x2 into
int4's so we can have a fused kernel for it. note, even if the user
isn't specifically doing this, the two operations are mathematically
equilvanet so it won't cause issues (for some reason int8 bitwise logic
in triton and pytorch doesn't match so that's the only exception). Ideally
at some point full prologue fusion for the mm arguments would be able to
handle this chain but until then, this type of kernel is needed.
Test Plan:
python test/inductor/test_pattern_matcher.py -k "uint4x2"
print test/inductor/test_torchinductor.py -k "uint4x2"
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106516
Approved by: https://github.com/jansel