[inductor] add decompostition for mm in backward (#120933)
Summary:
1) As a follow up in D53602514. Found a new way to decompose mm in backward. Sum the permuted input and reduce along 0 dim. Some benchmark result P1190140001. 30x speedup
Some explanations on why the original mm decomposition is slow. For mxkxn mm, when m is small and k is large, the stride for lhs is [m,1], hence it need to access memory k times to load all the data. As a result, decomposition will be effective with permute since the stride will be [k,1].
2) add another pattern for large k. benchmark result P1190596489 28x speedup
3) fix the value not found error in ig ctr. f536115499
Test Plan:
pt2 decompose:
{F1462894821}
decompose: f536159404
baseline: f536282578
705k vs 725k 4% for ig ctr
Differential Revision: D54294491
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120933
Approved by: https://github.com/mengluy0125