pytorch
2f96981e - [inductor] Reduce duplication of reduction combine functions (#99661)

Commit
1 year ago
[inductor] Reduce duplication of reduction combine functions (#99661) Currently reduction bodies are duplicated in several different places. This reduces duplication by `combine_fn` definition used in `_unroll_reduction_fn` and using it in the triton codegen. For cpp this also makes better use of `reduction_combine{,_vec}` by using them to generate the `omp declare reduction` line and the `vec_reduce_all` call. For triton the only change is that that the combine step gets spread over two lines, e.g. instead of: ```python _tmp1 = tl.where(rmask & xmask, triton_helpers.maximum(_tmp1, tmp0), _tmp1) ``` we get ```python tmp2 = triton_helpers.maximum(_tmp1, tmp0) _tmp1 = tl.where(rmask & xmask, tmp2, _tmp1) ``` For cpp the only change is that inplace reduction operations are now written as an out-of-place operation and an assignment, e.g. instead if ```cpp omp_out += omp_in ``` we generate ```cpp omp_out = omp_out + omp_in ``` Which is a purely cosmetic change Pull Request resolved: https://github.com/pytorch/pytorch/pull/99661 Approved by: https://github.com/lezcano, https://github.com/ngimel
Author
Committer
Parents
Loading