[discussion] fix for aot autograd outputs that don't require grad (#86838)
Fixes https://github.com/pytorch/functorch/issues/1052
I got here after some discussion with Alban. Today, if you trace a program with `aot_function()` where some of its inputs have `requires_grad=True` but some outputs are expected to have `requires_grad=False`, we incorrectly set all outputs to have `requires_grad=True`. A minimal repro is sketched below.
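A rough sketch of the reported behavior, assuming `functorch.compile.aot_function` and the `nop` passthrough compiler are available (these names are not from the original text and the exact printout depends on whether the fix is applied):

```python
import torch
from functorch.compile import aot_function, nop  # nop just returns the traced graph


def f(x):
    # The second output is detached, so in eager mode it does not require grad.
    return x * 2, x.detach() + 1


x = torch.randn(3, requires_grad=True)
compiled_f = aot_function(f, fw_compiler=nop)
out_grad, out_no_grad = compiled_f(x)

# Eager gives (True, False). Before this fix, the compiled function
# reported requires_grad=True for both outputs.
print(out_grad.requires_grad, out_no_grad.requires_grad)
```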
A simple solution is to use `autograd.Function`'s API for marking outputs as non-differentiable, based on what we observed when we traced the forward. The sketch below illustrates the mechanism.
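A minimal sketch of that mechanism using `ctx.mark_non_differentiable` inside a custom `torch.autograd.Function`; the class and the toy forward/backward here are illustrative only, not the actual aot_autograd code:

```python
import torch


class CompiledFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out_grad = x * 2
        out_no_grad = x.detach() + 1
        # Mark the output that was observed not to require grad at trace time;
        # autograd will leave its requires_grad as False.
        ctx.mark_non_differentiable(out_no_grad)
        return out_grad, out_no_grad

    @staticmethod
    def backward(ctx, grad_out, _grad_for_non_diff_output):
        # backward still accepts one gradient per forward output; the gradient
        # for the non-differentiable output is ignored.
        return grad_out * 2


x = torch.randn(3, requires_grad=True)
a, b = CompiledFn.apply(x)
print(a.requires_grad, b.requires_grad)  # True False
```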
This makes the `autograd.Function` that we return **wrong** if you created it using inputs that required grad and then re-use it with inputs whose `requires_grad` fields differ. But as long as we're hiding behind dynamo, which should guard on `requires_grad`, we'll re-run `aot_function()` and get back a new compiled function that does the right thing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86838
Approved by: https://github.com/ezyang