pytorch
fce67800 - [TensorExpr] Extend arithmetic simplifier to work with multi variable expressions (#35127)

Commit
4 years ago
[TensorExpr] Extend arithmetic simplifier to work with multi variable expressions (#35127) Summary: A new version of the IR simplifier used by the jit/tensorexpr fuser. This is capable of simplifying expressions containing (shock) multiple variables, eg: ```(m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) => 1``` Similar to the previous IR Simplifier it uses a two stage approach: 1. Traverse the tree combining subtree's of commutable operations in to a flat structure. In this implementation we have two intermediate Exprs: Term (expressing products of sub expressions) and Polynomial (expressing sums of sub expressions). 2. Traverse the tree expanding Term's and Polynomials into their component operators. Using the example above we execute with a process like this to simplify: ``` (m * (1 * n_1) + (n + 1)) - (m * (1 * n_1) + n) # Using PolynomialTransformer: => Sub(Add(Mul(m, Mul(1, n_1)), Add(n, 1)), Add(Mul(m, Mul(1, n_1)), n)) => Sub(Polynomial(Term(m, n_1), n, 1), Polynomial(Term(m, n_1), n)) => Polynomial(Term(m, n_1), Term(-1, m, n_1), n, -n, 1) => Polynomial(1) # Using TermExpander => 1 ``` The IRSimplifier supports arithmetic simplifications of operators Add, Sub and Mul and constant folding of all binary Exprs and Intrinsics, but does not attempt expansion of multiplication of Polynomials to the canonical form since that generally leads to less efficient representations. It will do scalar factorization if it results in removal of operators, and will merge chains of multilane primitives (such as Broadcast and Ramp) down into a single operator. The ir_simplifier unit tests are a short tour of its capabilities. The existing simplifier has a bug where it will sometimes reorder operations on floating point types which are not associative. This causes (at least) the pyhpc equation_of_state benchmark to produce incorrect results. I have fixed that issue in this version and verified that that benchmark produces the same results with and without the simplifier. Tests: all cpp & py tensorexpr tests, and pyphc benchmark: ``` benchmarks.equation_of_state ============================ Running on CPU size backend calls mean stdev min 25% median 75% max Δ ------------------------------------------------------------------------------------------------------------------ 4,194,304 pytorch 10 0.246 0.002 0.243 0.245 0.246 0.248 0.250 1.000 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/35127 Differential Revision: D20624571 Pulled By: nickgg fbshipit-source-id: e49049377beee69e02dcf26eb922bef1447ae776
Author
Parents
Loading