pytorch commit 3bc32799: PyDispatcher integration with functorch (#88785)

PyDispatcher integration with functorch (#88785)

This PR teaches PyDispatcher and PyOperator about functorch transforms.

It is important that PyDispatcher/PyOperator dispatch with functorch transforms, because this is our plan for higher-order operators (operators that accept functions as arguments). Examples of these include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards)
- AOTDispatcher (which should be a higher-order operator)

Concretely, the problem with teaching PyDispatcher/PyOperator about functorch is that the stack-based dispatching logic (DynamicLayerStack) is hidden inside the fallbacks for two dispatch keys (DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed fallbacks; our plan of record is to reimplement all of them in Python (they can call helper functions in C++ to make our lives easier).

Instead of exposing all of what DynamicLayer{Front, Back} do to Python, this PR takes the approach of re-implementing part of the stack-based dispatching in Python. The motivation is that this is saner and follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery (functorch needs this hackery today to re-use the VariableType implementations)

This PR:
- exposes the DynamicLayerStack to Python. The DynamicLayerStack is a stack of Interpreters; these get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to the next interpreter in the stack (Interpreter.lower).
- To use a PyOperator with functorch transforms, a developer needs to register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function support for functorch will end up going through the autograd.Function API.

Questions for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function" work into logical pieces. Would it be better if I didn't? (The full thing is a bit long - 1000-2000 LOC.)

Test Plan:
- new tests that construct PyOperator and compose them with functorch transforms

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer
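To make the design concrete, below is a minimal, self-contained Python sketch of the scheme described above: a stack of interpreters, process/lower, and per-transform rules registered on a PyOperator. All names and signatures here (Interpreter, PyOperator, py_impl, the string transform keys) are illustrative stand-ins, not the actual API added by this PR:

```python
from contextlib import contextmanager

# The "DynamicLayerStack": a stack of active transforms. In this toy,
# it is just a module-level list.
INTERPRETER_STACK = []


class Interpreter:
    """One functorch transform ("mode") living on the stack."""

    def __init__(self, key):
        self.key = key  # e.g. "vmap", "grad", "jvp"

    def process(self, op, args, kwargs):
        # Run the rule registered on the operator for this transform.
        rule = op.rules[self.key]
        return rule(self, *args, **kwargs)

    @contextmanager
    def lower(self):
        # Temporarily pop this interpreter so that re-dispatching the
        # operator reaches the next layer down the stack.
        popped = INTERPRETER_STACK.pop()
        try:
            yield
        finally:
            INTERPRETER_STACK.append(popped)


class PyOperator:
    """A Python-dispatched operator with one rule per transform."""

    def __init__(self, name):
        self.name = name
        self.rules = {}

    def py_impl(self, key):
        # Decorator: register the rule for one transform key.
        def register(rule):
            self.rules[key] = rule
            return rule
        return register

    def __call__(self, *args, **kwargs):
        if INTERPRETER_STACK:
            # Dispatch goes to the topmost interpreter on the stack.
            return INTERPRETER_STACK[-1].process(self, args, kwargs)
        return self.rules["base"](*args, **kwargs)


# Usage: an operator with a base rule and a "grad" rule.
my_op = PyOperator("my_op")


@my_op.py_impl("base")
def my_op_base(x):
    return x * 2


@my_op.py_impl("grad")
def my_op_grad(interpreter, x):
    # A transform rule typically does its transform-specific work,
    # then lowers itself off the stack and re-invokes the operator.
    with interpreter.lower():
        return my_op(x)


INTERPRETER_STACK.append(Interpreter("grad"))
print(my_op(3))  # routed through the "grad" rule, then the base rule: 6
```

The point the sketch illustrates is the "ideal" shape argued for above: each transform is just a mode on a stack, and a rule handles its own transform then lowers to whatever sits underneath, with no TLS dispatch key set hackery.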