[TensorExpr] IRPrinter: show output_args separate from reduce_args when printing ReduceOp. (#37367)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37367
Before this change we printed all the args in the same list, for example:
```
BEFORE RFACTOR:
{
for (int m = 0; m < m_1; m++) {
for (int n = 0; n < n_1; n++) {
sum[0] = ReduceOp(sum, float(0), (sum[0]) + (b[m, n]), {m, n});
}
}
}
AFTER RFACTOR:
{
for (int m = 0; m < m_1; m++) {
for (int n = 0; n < n_1; n++) {
tmp_buf[n] = ReduceOp(tmp_buf, float(0), (tmp_buf[n]) + (b[m, n]), {nm}); # <<< n is out, m is reduce here
}
}
for (int n = 0; n < n_1; n++) {
sum[0] = ReduceOp(sum, float(0), (sum[0]) + (tmp_buf[n]), {n});
}
}
```
With this change we explicitly show which args are reduce args:
```
BEFORE RFACTOR:
{
for (int m = 0; m < m_1; m++) {
for (int n = 0; n < n_1; n++) {
sum[0] = ReduceOp(sum, float(0), (sum[0]) + (b[m, n]), out_args={}, reduce_args={m, n});
}
}
}
AFTER RFACTOR:
{
for (int m = 0; m < m_1; m++) {
for (int n = 0; n < n_1; n++) {
tmp_buf[n] = ReduceOp(tmp_buf, float(0), (tmp_buf[n]) + (b[m, n]), out_args={n}, reduce_args={m});
}
}
for (int n = 0; n < n_1; n++) {
sum[0] = ReduceOp(sum, float(0), (sum[0]) + (tmp_buf[n]), out_args={}, reduce_args={n});
}
}
```
Test Plan: Imported from OSS
Differential Revision: D21265807
Pulled By: ZolotukhinM
fbshipit-source-id: 384396cd55562570f8e33657b856a4404d451080
Author
Mikhail Zolotukhin