pytorch
b01a15d3 - [TensorExpr] Redesign Rfactor loopnest transformation. (#55324)

Commit
3 years ago
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55324

With this change `rfactor` only affects the passed loop and its body, never touching anything outside (that was the root cause of a bug in the previous implementation). Also, we don't have an `insertion_point` parameter anymore: its meaning was vague, and its effect should be achievable with other transformations anyway.

The new `rfactor` semantics is as follows:
```
Requirements:
 * S is the reduction store
 * S is the only statement in the innermost loop
 * There are at least two reduction arguments in S
 * OUTER_REDUCTION_FOR loop corresponds to the outermost reduction variable
   used in the store and all other reduction variables are index variables
   of children loops of OUTER_REDUCTION_FOR
 * OUTER_REDUCTION_FOR is a perfect loop nest, i.e. it has only loops
   corresponding to the other reduction variables and the store, nested
   into each other

What it does:
 * Introduce a new buffer with an extra dimension of a size equal to the
   span of the loop OUTER_REDUCTION_FOR (the new buffer is returned via
   RFAC_BUF_PTR)
 * Insert an initialization store for the new buffer in
   OUTER_REDUCTION_FOR before its nested loop
 * Replace the reduction store to the original buffer with the reduction
   store to the temp buffer, removing the index var of OUTER_REDUCTION_FOR
   from reduction arguments
 * Insert a final reduction store over the extra dimension of the new
   buffer to the original buffer
 * Returns TRUE if the transformation succeeded and FALSE otherwise

Example:
Original IR:
S1: for i        # normal axis
S2:   X[i] = 0
S3:   for j      # reduction axis
S4:     for k    # reduction axis
S5:       X[i] = ReduceOp(X[i] + Y[i,j,k], reduce_axis={j,k})

After RFACTOR(S5, S3)
S1: for i               # normal axis
S2:   X[i] = 0
S3:   for j             # reduction axis for X, normal axis for X_rfac
        X_rfac[i,j] = 0
S4:     for k           # reduction axis
          X_rfac[i,j] = ReduceOp(X_rfac[i,j] + Y[i,j,k], reduce_axis={k})
        X[i] = ReduceOp(X[i] + X_rfac[i,j], reduce_axis={j})
```

Differential Revision: D27694960

Test Plan: Imported from OSS

Reviewed By: navahgar

Pulled By: ZolotukhinM

fbshipit-source-id: 076fa6a1df2c23f5948302aa6b43e82cb222901c
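
The before/after IR in the commit message can be checked numerically with a small sketch. This is plain Python, not the TensorExpr API: it assumes `ReduceOp` is a sum, and the function names `reduce_original` and `reduce_rfactored` are hypothetical, chosen only for this illustration. The loops mirror statements S1 to S5 in the example.

```python
def reduce_original(Y):
    """Original nest: X[i] accumulates over both reduction axes j and k."""
    I, J, K = len(Y), len(Y[0]), len(Y[0][0])
    X = [0] * I
    for i in range(I):                 # S1: normal axis
        X[i] = 0                       # S2: initializer
        for j in range(J):             # S3: reduction axis
            for k in range(K):         # S4: reduction axis
                X[i] = X[i] + Y[i][j][k]   # S5: reduce over {j, k}
    return X


def reduce_rfactored(Y):
    """Nest after RFACTOR(S5, S3), assuming ReduceOp is a sum.

    X_rfac has one extra dimension whose size equals the span of the j loop.
    """
    I, J, K = len(Y), len(Y[0]), len(Y[0][0])
    X = [0] * I
    X_rfac = [[0] * J for _ in range(I)]
    for i in range(I):                 # S1: normal axis
        X[i] = 0                       # S2: initializer
        for j in range(J):             # S3: reduction for X, normal for X_rfac
            X_rfac[i][j] = 0           # new initialization store
            for k in range(K):         # S4: reduction axis
                X_rfac[i][j] = X_rfac[i][j] + Y[i][j][k]   # reduce over {k}
            X[i] = X[i] + X_rfac[i][j]  # final reduction over {j}
    return X, X_rfac


if __name__ == "__main__":
    Y = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    X_before = reduce_original(Y)
    X_after, X_rfac = reduce_rfactored(Y)
    assert X_before == X_after
    print(X_after, X_rfac)
```

Both nests produce the same `X`. The rfactored version also keeps the per-`j` partial sums in `X_rfac`, which is the buffer the description says is returned via RFAC_BUF_PTR.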
Author
Mikhail Zolotukhin