[inductor] Constant and index_expr propagation pass (#101077)
This pass does a limited form of constant propagation, as well as propagation of
sympy indexing expressions. For example, say you have the function:
```python
def flip(x):
i = torch.arange(x.size(0) - 1, -1, -1, device=x.device)
return x[i]
```
On current main this results in indirect indexing:
```python
class buf0_loop_body:
var_ranges = {z0: 4, z1: 3}
index0 = 3 - z0
index1 = 3*indirect0 + z1
index2 = 3*z0 + z1
def body(self, ops):
get_index = self.get_index('index0')
index_expr = ops.index_expr(get_index, torch.int64)
set_indirect0 = self.set_indirect0(index_expr)
get_index_1 = self.get_index('index1')
load = ops.load('arg0_1', get_index_1)
get_index_2 = self.get_index('index2')
store = ops.store('buf0', get_index_2, load, None)
return store
```
With this PR the indexing is propagated through the computation and into direct
indexing:
```python
class buf0_loop_body:
var_ranges = {z0: 4, z1: 3}
index0 = -3*z0 + z1 + 9
index1 = 3*z0 + z1
def body(self, ops):
get_index = self.get_index('index0')
load = ops.load('arg0_1', get_index)
get_index_1 = self.get_index('index1')
store = ops.store('buf0', get_index_1, load, None)
return store
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101077
Approved by: https://github.com/lezcano, https://github.com/ngimel