aot autograd: handle detach() and no_grad() mutations on input (#95980)
Fixes https://github.com/pytorch/pytorch/issues/95167
More details are in that issue. To summarize, the issue shows up when we have some code like this:
```
def f(x):
    x.detach().mul_(2)  # can also happen if the mul_() happens under torch.no_grad()
    return x + 1
```
AOTAutograd will then spit out code like this:
```
def compiled_fn(x):
    x_updated = x.mul(2)
    out = x_updated + 1
    return x_updated, out

def CompiledFunction.forward(x):  # pseudocode, this is part of an autograd.Function
    x_updated, out = compiled_fn(x)
    return x_updated, out

def runtime_wrapper(x):
    x_updated, out = CompiledFunction.apply(x)
    x.copy_(x_updated)
    return out

x = torch.ones(2, requires_grad=True)
out = runtime_wrapper(x)
```
However, the call to `x.copy_(x_updated)` will fail with the error: `a leaf Variable that requires grad is being used in an in-place operation`. This is because `x` is an autograd leaf, and autograd doesn't allow you to mutate leaves.
In this case though, the data mutation should be entirely opaque to autograd - all mutations happened underneath a `.detach()` or a `torch.no_grad()`.
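For reference, this is a minimal eager-mode repro of that autograd behavior (standalone, not part of the traced code above):
```
import torch

x = torch.ones(2, requires_grad=True)

# Mutating an autograd leaf in-place is an error:
try:
    x.mul_(2)
except RuntimeError as e:
    print(e)  # a leaf Variable that requires grad is being used in an in-place operation.

# But the same mutation is fine when it is hidden from autograd:
x.detach().mul_(2)   # ok: mutates x's data, invisible to autograd
with torch.no_grad():
    x.mul_(2)        # also ok
```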
As Ed pointed out in the issue, we can detect this situation by checking whether the mutated input is an autograd leaf. If it is, then any mutations on it must have been hidden from autograd, since otherwise the eager code would have errored. The solution I added is to detect this situation and manually run `x.detach().copy_(x_updated)`, so the update is also hidden from autograd.
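Roughly, the runtime wrapper now applies the input mutation along these lines (a sketch only; the function name and exact condition here are illustrative, not the actual AOTAutograd code):
```
def apply_input_mutation(x, x_updated):
    # Sketch: if the mutated input is an autograd leaf that requires grad,
    # the eager mutation must have been hidden from autograd (via .detach()
    # or torch.no_grad()), so hide the copy_ from autograd as well.
    if x.is_leaf and x.requires_grad:
        x.detach().copy_(x_updated)
    else:
        x.copy_(x_updated)
```
With this, the `runtime_wrapper` above succeeds instead of raising the leaf-mutation error.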
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95980
Approved by: https://github.com/ezyang