[TensorExpr] add support for Reduction Ops (#35866)
Summary:
Second attempt at the reduction frontend for the TensorExpr compiler. Has two APIs, a simple version for common reduction types and a customizable Reducer fronted which allows specifying initializer, reduction interaction via lambda and body via lambda.
Simple API looks like so:
```
Buffer b(BufHandle("b", {10}), kInt);
Tensor* c = Reduce("sum", {}, Sum(b), {{10, "m"}});
```
An example of specializing a Sum to do Matmul:
```
Buffer tA(BufHandle("tA", {M, K}), kFloat);
Buffer tB(BufHandle("tB", {K, N}), kFloat);
Sum matmul([&](ParameterList& v) {
ExprHandle m = v[0];
ExprHandle n = v[1];
ExprHandle k = v[2];
return tA(m, k) * tB(k, n);
});
Tensor* mm = Reduce("mm", {{M, "m"}, {N, "n"}}, matmul, {{K, "k"}});
```
A fully specialized Reduction:
```
VarHandle searchValue("searchValue", kInt);
Buffer b(BufHandle("b", {4, 10}), kInt);
Reducer anyEqSV(
ExprHandle(0),
[](ExprHandle a, ExprHandle b) {
return CompareSelect::make(a, 1, 1, b, kEQ);
},
[&](ParameterList& v) {
return CompareSelect::make(b.call(v), searchValue, kEQ);
});
Tensor* any = Reduce("anyEqual", {{4, "i"}}, anyEqSV, {{10, "j"}});
```
---
Until lowering, Reductions are held in a compound form for easier optimization:
```
VarHandle m("m", kInt);
Buffer b(BufHandle("b", {2, 3, m}), kFloat);
Tensor* c = Reduce("sum", {{2, "l"}, {3, "n"}}, Sum(b), {{m, "m"}});
LoopNest loop({c});
std::cout << *loop.root_stmt() << "\n";
```
```
for (int l = 0; l < 2; l++) {
for (int n = 0; n < 3; n++) {
for (int m = 0; m < m_1; m++) {
sum[l, n] = ReduceOp(sum[l, n] = float(0);, (sum[l, n]) + (b[l, n, m]), {m});
}
}
}
```
```
loop.prepareForCodegen();
std::cout << *loop.root_stmt() << "\n";
```
```
for (int l = 0; l < 2; l++) {
for (int n = 0; n < 3; n++) {
sum[(0 + l * (1 * 3)) + n * 1] = float(0);
for (int m = 0; m < m_1; m++) {
sum[(0 + l * (1 * 3)) + n * 1] = (sum[(0 + l * (1 * 3)) + n * 1]) + (b[((0 + l * ((1 * m_1) * 3)) + n * (1 * m_1)) + m * 1]);
}
}
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35866
Differential Revision: D20965577
Pulled By: nickgg
fbshipit-source-id: afe506c90db794447180056417013bcaf0e2c049