pytorch
81f351ac - [inductor] Prevent blowup in inner_fn_str and extract_read_writes (#88933)

Commit
1 year ago
[inductor] Prevent blowup in inner_fn_str and extract_read_writes (#88933) Currently the default `ops` handler expects strings as arguments and just formats them into a function call template string. For complex expressions, this can lead to exponential growth in terms. Say for example you have: ```python def fn(a): for _ in range(3) a = ops.mul(a, a) return a ``` You might expect `inner_fn_str` to contain 1 load and 3 multiplies, but instead you find 8 loads and 7 multiplies: ```python load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) * load(arg_0, i0) ``` This type of blowup is present in the lowering for `max_pool2d_with_indices_backward` which in #pytorch/torchdynamo#1352 was reported to have caused the entire compilation to hang. This PR fixes the issue by formatting the string as a series of assignments to variables, so for the example above, we now get: ``` tmp0 = load(arg_0, i0) tmp1 = tmp0 * tmp0 tmp2 = tmp1 * tmp1 tmp3 = tmp2 * tmp2 return tmp3 ``` Which corresponds to sequence of `ops` calls made. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88933 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading