[inductor] Prevent blowup in inner_fn_str and extract_read_writes (#88933)
Currently the default `ops` handler expects strings as arguments and
just formats them into a function call template string. For complex
expressions, this can lead to exponential growth in terms. Say for
example you have:
```python
def fn(a):
for _ in range(3)
a = ops.mul(a, a)
return a
```
You might expect `inner_fn_str` to contain 1 load and 3 multiplies,
but instead you find 8 loads and 7 multiplies:
```python
load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0)
```
This type of blowup is present in the lowering for
`max_pool2d_with_indices_backward` which in #pytorch/torchdynamo#1352
was reported to have caused the entire compilation to hang.
This PR fixes the issue by formatting the string as a series of assignments to
variables, so for the example above, we now get:
```
tmp0 = load(arg_0, i0)
tmp1 = tmp0 * tmp0
tmp2 = tmp1 * tmp1
tmp3 = tmp2 * tmp2
return tmp3
```
Which corresponds to sequence of `ops` calls made.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88933
Approved by: https://github.com/jansel