PyDispatcher integration with functorch (#88785)
This PR teaches PyDispatcher and PyOperator about functorch transforms.
It is important that PyDispatcher/PyOperator dispatch with functorch
transforms, because this is our plan for higher-order operators
(operators that accept functions as arguments). Examples of these
include:
- functorch transforms over the existing cond operator (control flow)
- autograd.Function support for functorch (which I am working towards),
- AOTDispatcher (should be a higher order operator)
Concretely, the problem with teaching PyDispatcher/PyOperator about
functorch is that the stack-based dispatching logic (DynamicLayerStack)
is hidden inside the fallbacks for two dispatch keys
(DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed
fallbacks, our plan on record for that is that we need to reimplement
all of them in Python (but can call helper functions in C++ to make our
lives easier).
Instead of exposing all of what DynamicLayer{Front, Back} do to python,
this PR takes the approach of re-implementing part of the stack-based
dispatching in Python. The motivation is that this is more sane and
follows what the "ideal" implementation of functorch would have been:
- each transform should be a "mode"
- there should be no TLS dispatch key set hackery. functorch needs to do
this hackery today to re-use VariableType implementations.
This PR:
- exposes the DynamicLayerStack to Python
- The DynamicLayerStack is a stack of Interpreters.
These get exposed to Python as well.
- Interpreters can run operations (Interpreter.process) or lower them to
the next interpreter in the stack (Interpreter.lower)
- To use a PyOperator with functorch transforms, a developer needs to
register a rule for each transform (vmap, grad, jvp, ...).
- The PyOperator API is NOT user-facing. Things like autograd.Function
support for functorch will end up going through the autograd.Function
API.
Question for reviewers:
- Does this design make sense?
- I'm trying to split up the "functorch support for autograd.Function"
work into logical pieces. Would it be better if I didn't? (the full
thing is a bit long - 1000-2000 LOC).
Test Plan:
- new tests that construct PyOperator and compose them with functorch
transforms
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785
Approved by: https://github.com/samdow, https://github.com/soulitzer