pytorch
ef512db0 - [inductor] Constant and index_expr propagation pass (#101077)

Commit
1 year ago
[inductor] Constant and index_expr propagation pass (#101077) This pass does a limited form of constant propagation, as well as propagation of sympy indexing expressions. For example, say you have the function: ```python def flip(x): i = torch.arange(x.size(0) - 1, -1, -1, device=x.device) return x[i] ``` On current main this results in indirect indexing: ```python class buf0_loop_body: var_ranges = {z0: 4, z1: 3} index0 = 3 - z0 index1 = 3*indirect0 + z1 index2 = 3*z0 + z1 def body(self, ops): get_index = self.get_index('index0') index_expr = ops.index_expr(get_index, torch.int64) set_indirect0 = self.set_indirect0(index_expr) get_index_1 = self.get_index('index1') load = ops.load('arg0_1', get_index_1) get_index_2 = self.get_index('index2') store = ops.store('buf0', get_index_2, load, None) return store ``` With this PR the indexing is propagated through the computation and into direct indexing: ```python class buf0_loop_body: var_ranges = {z0: 4, z1: 3} index0 = -3*z0 + z1 + 9 index1 = 3*z0 + z1 def body(self, ops): get_index = self.get_index('index0') load = ops.load('arg0_1', get_index) get_index_1 = self.get_index('index1') store = ops.store('buf0', get_index_1, load, None) return store ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/101077 Approved by: https://github.com/lezcano, https://github.com/ngimel
Author
Committer
Parents
Loading