pytorch
4c1bc91f - Support autograd.Function w/ grad (#99483)

This PR adds support for tracing autograd.Function with grad. A few important bullet points outlining our approach:

1) Our goal is to verify soundness so that we can add a call_function to the autograd.Function's `apply` to the graph.
2) We achieve (1) by either verifying or rejecting soundness, ensuring that both the forward and backward of the autograd.Function are sound.
3) For the forward, if we verify soundness, we install its guards into the graph.
4) For the backward, if we verify soundness, we throw it out. Backward soundness verification is more onerous, however, and uses a config-driven set of banned attrs and methods for tensors.

Points 1-4 are achieved by turning the forward and backward into UserDefinedFunctionVariables and inlining through them, relying on dynamo's soundness detection. If we graph break in these, we raise and treat them as unsound. As noted above, backward is stricter still.

For the tracing, the safety comes from dynamo's HigherOrderOperator system. That system ensures not only that we trace soundly, but also that no new variables are lifted into inputs during the tracing, and that the forward and backward are entirely self-contained. Whenever we reject a function as unsound, we restore back, as usual.

Due to some limitations in the lifting logic, we implemented an escape hatch for tensors that are known in forward but cross into backward through save_for_backward (save) / saved_tensors (load). The escape hatch prevents known saved tensors coming from forward from being accidentally treated as lifted variables (and rejected). This is sound, but feels a little hacky.

Additionally, due to some limitations in fx node removal, combined with how we produce subgraphs for the traces installed from HigherOrderOperators, we had to improve our node removal logic. In the event of a restore, we remove the old nodes from the graph, as usual in dynamo.
However, because references to these nodes may exist in subgraphs, we traverse each node's users and remove them first, if and only if they are in another graph. This is always sound, because removal should only happen downstream of restoration at this point. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99483 Approved by: https://github.com/zou3519
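The kind of autograd.Function this change targets can be illustrated with a minimal sketch. `MySquare` below is a hypothetical example, not code from the PR: its forward and backward are the two bodies dynamo inlines for soundness checking, and the save_for_backward / saved_tensors pair is exactly the forward-to-backward crossing the escape hatch above is about. It is shown running in eager mode; the PR's contribution is that `torch.compile` can now trace a function like `fn` without graph-breaking on `apply`.

```python
import torch

# Hypothetical custom autograd.Function. With this PR, dynamo can inline
# both forward and backward to verify soundness and then emit a single
# call_function to MySquare.apply in the captured graph.
class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Tensors crossing from forward into backward go through
        # save_for_backward; the escape hatch described above keeps
        # them from being treated as newly lifted inputs.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

def fn(x):
    return MySquare.apply(x)

# Eager execution shown for brevity; after this PR, torch.compile(fn)
# would trace through apply rather than falling back.
x = torch.tensor([3.0], requires_grad=True)
y = fn(x)
y.backward(torch.ones_like(y))
```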
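The subgraph-aware removal order can be sketched in miniature. The `Node` and `Graph` classes below are toy stand-ins, not the actual torch.fx classes, and `remove_with_subgraph_users` is a hypothetical name; the point is only the ordering constraint: users that live in a different graph are erased before the node itself.

```python
# Toy model of the removal order described above: before erasing a node
# from its own graph, first erase any of its users that live in a
# *different* (sub)graph, so no dangling references remain there.
class Graph:
    def __init__(self, name):
        self.name = name
        self.nodes = []

    def erase(self, node):
        self.nodes.remove(node)

class Node:
    def __init__(self, name, graph):
        self.name = name
        self.graph = graph
        self.users = []
        graph.nodes.append(self)

def remove_with_subgraph_users(node):
    # Users in other graphs (e.g. subgraphs produced for a
    # HigherOrderOperator trace) go first; users in the same graph
    # are assumed to have been handled by ordinary removal.
    for user in list(node.users):
        if user.graph is not node.graph:
            user.graph.erase(user)
            node.users.remove(user)
    node.graph.erase(node)

main = Graph("main")
sub = Graph("sub")
a = Node("a", main)
u = Node("u", sub)   # a user of `a` living in a subgraph
a.users.append(u)
remove_with_subgraph_users(a)
```

This mirrors why the order is safe: by the time removal runs, a restore has already happened, so everything reachable from the removed node is itself dead.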