Make aot_autograd explicitly error on double backward (#92348)
Mitigates https://github.com/pytorch/pytorch/issues/91469
Changes:
- ~once_differentiable can now be parametrized to print a custom error message~
- instead of once_differentiable, we run the backward inside another custom Function, which keeps the graph connected but also makes sure to error on double backward (see the sketch after this list)
- we now explicitly error when doing double backward with torch.compile + aot_autograd instead of silently computing incorrect results. ~The niceness of the error message can vary depending on whether grad_outputs are passed, or whether you are doing `.grad()` or `.backward()`.~
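A minimal sketch of the pattern described above (hypothetical names, not the actual aot_autograd implementation): the outer Function's backward runs inside a second custom Function, so gradients still flow on the first backward, while that second Function's own backward raises to make double backward an explicit error.

```python
import torch

class _BackwardThatErrors(torch.autograd.Function):
    """Stand-in for the compiled backward; its own backward raises."""

    @staticmethod
    def forward(ctx, x, grad_out):
        # Stand-in for running the compiled backward graph; here it is
        # just the gradient of (x * x).sum() for illustration.
        return 2 * x * grad_out

    @staticmethod
    def backward(ctx, *grads):
        raise RuntimeError(
            "double backward is not supported here (illustrative error message)"
        )


class _CompiledSquare(torch.autograd.Function):
    """Stand-in for an aot_autograd-compiled forward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Run the backward inside the second custom Function so the graph
        # stays connected while double backward errors explicitly.
        return _BackwardThatErrors.apply(x, grad_out)


x = torch.randn(3, requires_grad=True)
y = _CompiledSquare.apply(x).sum()
(g,) = torch.autograd.grad(y, x, create_graph=True)  # first backward works
try:
    torch.autograd.grad(g.sum(), x)  # second backward errors explicitly
except RuntimeError as e:
    print("double backward errored:", e)
```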
Unchanged:
- doing backward inside a compiled function is still allowed. It currently causes a graph break and is equivalent to doing backward outside the compiled function (see the sketch below). It might be nice to disallow this explicitly as well, but that can be done in a follow-up.
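As a quick illustration of that unchanged behavior (assumes a build where torch.compile is available), calling `.backward()` inside a compiled function breaks the graph and runs eagerly, so the resulting gradients match calling backward outside the compiled region:

```python
import torch

@torch.compile
def train_step(x):
    loss = (x * x).sum()
    loss.backward()  # graph break: backward runs outside the compiled graph
    return loss

x = torch.randn(4, requires_grad=True)
train_step(x)
print(x.grad)  # same result as calling backward outside the compiled function
```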
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92348
Approved by: https://github.com/albanD