add torch.autograd._set_view_replay_enabled, use in aot autograd (#92588)
tldr; this should fix some minor perf regressions that were caused by adding more as_strided() calls in aot autograd.
This PR adds a new context manager, `torch.autograd._set_view_replay_enabled()`.
Context: AOT Autograd has special handling for "outputs that alias graph intermediates". E.g. given this function:
```
def f(x):
y = torch.mul(x, 2)
out = y.view(-1)
return out
```
AOT Autograd will do the following:
```
def fn_to_compile(x):
y = torch.mul(x, 2)
out = y.view(-1)
# return the graph intermediate
return y, out
compiled_fn = compile(fn_to_compile)
def wrapper(x):
y, out = compiled_fn(x)
# regenerate the alias of the graph intermediate
return out._view_func(y)
```
What's annoying is that `out._view_func()` will result in a `.as_strided` call, because `out` is an ordinary runtime tensor. This (likely?) caused a perf regression, because when running the backward, out `as_strided_backward()` is slower than our `view_backward()`.
In this PR, I added some TLS for instructing autograd to do view replay instead of as_strided, even when given a normal tensor. I'm definitely interested in thoughts from autograd folks (cc @albanD @soulitzer). A few points that I want to bring up:
(1) One reason that this API seems generally useful to me is because of the case where you `torch.compile()` a function, and you pass in two inputs that alias each other, and mutate one of the inputs. Autograd is forced to add a bunch of as_strided() calls into the graph when this happens, but this would give users an escape hatch for better compiled perf in this situation
(2) To be fair, AOT Autograd probably won't need this TLS in the long term. There's a better (more complicated) solution, where AOT Autograd manually precomputes the view chain off of graph intermediates during tracing, and re-applies them at runtime. This is kind of complicated though and feels lower priority to implement immediately.
(3) Given all of that I made the API private, but lmk what you all think.
This is a followup of https://github.com/pytorch/pytorch/pull/92255.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92588
Approved by: https://github.com/ezyang, https://github.com/albanD