[TensorExpr] Fix bug when splitting inner reduce axis with tail (#38420)
Summary:
Fixes a bug in the following code:
```
// b is a 10x10x10 input buffer; sum over the n and k axes.
Tensor* c = Reduce("sum", {{10, "m"}}, Sum(), b, {{10, "n"}, {10, "k"}});
LoopNest loop({c});
For *outer, *inner, *tail;
// split the N loop with a tail:
loop.splitWithTail(loop.getLoopStmtsFor(c)[1], 8, &outer, &inner, &tail);
```
When this is expanded there are two ReduceOps:
```
for (int m = 0; m < 10; m++) {
  for (int n_outer = 0; n_outer < (10 - 0) / 8; n_outer++) {
    for (int n_inner = 0; n_inner < 8; n_inner++) {
      for (int k = 0; k < 10; k++) {
        sum[m] = ReduceOp(sum, float(0), (sum[m]) + (b[m, n_outer * 8 + n_inner, k]), out_args={m}, reduce_args={n_inner, n_outer, k});
      }
    }
  }
  for (int n_tail = 0; n_tail < (10 - 0) % 8; n_tail++) {
    for (int k = 0; k < 10; k++) {
      sum[m] = ReduceOp(sum, float(0), (sum[m]) + (b[m, n_tail + ((10 - 0) / 8) * 8, k]), out_args={m}, reduce_args={n_tail, k});
    }
  }
}
```
But each ReduceOp expands its initializer, which in this case overwrites the partial sum accumulated by the split loop:
```
for (int m = 0; m < 10; m++) {
  sum[m] = 0.f;
  for (int n_inner = 0; n_inner < 8; n_inner++) {
    for (int k = 0; k < 10; k++) {
      sum[m] = (sum[m]) + (b[(100 * m + k) + 10 * n_inner]);
    }
  }
  sum[m] = 0.f; <------- *HERE*
  for (int n_tail = 0; n_tail < 2; n_tail++) {
    for (int k = 0; k < 10; k++) {
      sum[m] = (sum[m]) + (b[((100 * m + k) + 10 * n_tail) + 80]);
    }
  }
}
```
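The correct lowering initializes sum[m] once and lets the tail loop accumulate into the partial sum; derived from the dump above, it is the same code minus the second `sum[m] = 0.f;`:
```
for (int m = 0; m < 10; m++) {
  sum[m] = 0.f;
  for (int n_inner = 0; n_inner < 8; n_inner++) {
    for (int k = 0; k < 10; k++) {
      sum[m] = (sum[m]) + (b[(100 * m + k) + 10 * n_inner]);
    }
  }
  for (int n_tail = 0; n_tail < 2; n_tail++) {
    for (int k = 0; k < 10; k++) {
      sum[m] = (sum[m]) + (b[((100 * m + k) + 10 * n_tail) + 80]);
    }
  }
}
```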
The simplest fix is to remove the initializer from the tail loop, which requires adding support for reductions without an initializer (I did this by adding a NoOp Expr rather than by handling nullptr). Also moved the ReductionExpander from loopnest.cpp to reduction.h, since loopnest.cpp is getting a bit heavy.
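A minimal self-contained sketch of the idea; the stand-in types below are hypothetical, and the real Expr/ReduceOp classes in torch/csrc/jit/tensorexpr differ:
```
#include <memory>

// Hypothetical stand-ins to illustrate the shape of the change; the real
// IR classes live in torch/csrc/jit/tensorexpr and have richer interfaces.
struct Expr {
  virtual ~Expr() = default;
};
struct NoOp : Expr {};  // "no initializer" marker, avoids passing nullptr

struct ReduceOp {
  std::shared_ptr<Expr> initializer;  // set to NoOp for the tail loop
};

// The expander only emits `sum[m] = 0.f;` for ops that carry a real
// initializer; the tail loop's ReduceOp (NoOp) just accumulates.
bool shouldEmitInitializer(const ReduceOp& op) {
  return dynamic_cast<const NoOp*>(op.initializer.get()) == nullptr;
}
```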
Added tests for all kinds of splits on a simple 3D reduction to verify there are no more problems of this type.
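The tests follow the pattern below; this is a condensed sketch using the tensorexpr API from the snippet at the top, and harness details such as KernelScope, Buffer, and SimpleIREvaluator are assumptions about the test setup of this era:
```
// Sketch of one split-with-tail test; exact harness names may differ.
KernelScope kernel_scope;
Buffer b(BufHandle("b", {10, 10, 10}, kFloat));
std::vector<float> in(10 * 10 * 10);
for (int j = 0; j < 10 * 10 * 10; ++j) {
  in[j] = j;
}

// Split each of the three axes in turn; every variant must produce the
// same sums as the unsplit reduction.
for (int i = 0; i < 3; ++i) {
  std::vector<float> out(10, -1.f);

  Tensor* c = Reduce("sum", {{10, "m"}}, Sum(), b, {{10, "n"}, {10, "k"}});
  LoopNest loop({c});
  For *outer, *inner, *tail;
  loop.splitWithTail(loop.getLoopStmtsFor(c)[i], 8, &outer, &inner, &tail);

  loop.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c});
  cg.call({in, out});

  // out[m] should equal the sum over n and k of b[m, n, k].
}
```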
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38420
Differential Revision: D21587583
Pulled By: nickgg
fbshipit-source-id: e0766934481917007119612eb60cc76c3242e44a