pytorch
5a59bbc1 - [TensorExpr] IRPrinter: show output_args separate from reduce_args when printing ReduceOp. (#37367)

Commit
4 years ago
[TensorExpr] IRPrinter: show output_args separate from reduce_args when printing ReduceOp. (#37367) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37367 Before this change we printed all the args in the same list, for example: ``` BEFORE RFACTOR: { for (int m = 0; m < m_1; m++) { for (int n = 0; n < n_1; n++) { sum[0] = ReduceOp(sum, float(0), (sum[0]) + (b[m, n]), {m, n}); } } } AFTER RFACTOR: { for (int m = 0; m < m_1; m++) { for (int n = 0; n < n_1; n++) { tmp_buf[n] = ReduceOp(tmp_buf, float(0), (tmp_buf[n]) + (b[m, n]), {nm}); # <<< n is out, m is reduce here } } for (int n = 0; n < n_1; n++) { sum[0] = ReduceOp(sum, float(0), (sum[0]) + (tmp_buf[n]), {n}); } } ``` With this change we explicitly show which args are reduce args: ``` BEFORE RFACTOR: { for (int m = 0; m < m_1; m++) { for (int n = 0; n < n_1; n++) { sum[0] = ReduceOp(sum, float(0), (sum[0]) + (b[m, n]), out_args={}, reduce_args={m, n}); } } } AFTER RFACTOR: { for (int m = 0; m < m_1; m++) { for (int n = 0; n < n_1; n++) { tmp_buf[n] = ReduceOp(tmp_buf, float(0), (tmp_buf[n]) + (b[m, n]), out_args={n}, reduce_args={m}); } } for (int n = 0; n < n_1; n++) { sum[0] = ReduceOp(sum, float(0), (sum[0]) + (tmp_buf[n]), out_args={}, reduce_args={n}); } } ``` Test Plan: Imported from OSS Differential Revision: D21265807 Pulled By: ZolotukhinM fbshipit-source-id: 384396cd55562570f8e33657b856a4404d451080
Author
Mikhail Zolotukhin
Parents
Loading