pytorch
8bf31245 - [TensorExpr] Fix bug when splitting inner reduce axis with tail (#38420)

Commit
4 years ago
[TensorExpr] Fix bug when splitting inner reduce axis with tail (#38420) Summary: Fixes a bug in the following code: ``` Tensor* c = Reduce("sum", {{10, "m"}}, Sum(), b, {{10, "n"}, {10, "k"}}); // split N loop with tail: loop.splitWithTail(loop.getLoopStmtsFor(c)[1], 8, &outer, &inner, &tail); ``` When this is expanded there are two ReduceOps: ``` for (int m = 0; m < 10; m++) { for (int n_outer = 0; n_outer < (10 - 0) / 8; n_outer++) { for (int n_inner = 0; n_inner < 8; n_inner++) { for (int k = 0; k < 10; k++) { sum[m] = ReduceOp(sum, float(0), (sum[m]) + (b[m, n_outer * 8 + n_inner, k]), out_args={m}, reduce_args={n_inner, n_outer, k}); } } } for (int n_tail = 0; n_tail < (10 - 0) % 8; n_tail++) { for (int k = 0; k < 10; k++) { sum[m] = ReduceOp(sum, float(0), (sum[m]) + (b[m, n_tail + ((10 - 0) / 8) * 8, k]), out_args={m}, reduce_args={n_tail, k}); } } } ``` But each ReduceOp will expand it's initializer, which in this case will overwrite the sum of the split loop: ``` for (int m = 0; m < 10; m++) { sum[m] = 0.f; for (int n_inner = 0; n_inner < 8; n_inner++) { for (int k = 0; k < 10; k++) { sum[m] = (sum[m]) + (b[(100 * m + k) + 10 * n_inner]); } } sum[m] = 0.f; <------- *HERE* for (int n_tail = 0; n_tail < 2; n_tail++) { for (int k = 0; k < 10; k++) { sum[m] = (sum[m]) + (b[((100 * m + k) + 10 * n_tail) + 80]); } } } ``` The simplest fix is to remove the initializer from the tail loop, which requires adding support for Reductions without an initializer (I did via adding a NoOp Expr rather than handling nullptr). Also moved the ReductionExpander from loopnest.cpp to reduction.h as loopnest is getting a bit heavy. Added tests for all kinds of splits on a simple 3D reduction to verify no more problems of this type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/38420 Differential Revision: D21587583 Pulled By: nickgg fbshipit-source-id: e0766934481917007119612eb60cc76c3242e44a
Author
Parents
Loading