[custom_op] explicit autograd API (#101824)
This PR adds an explicit API for registering a backward formula for a
CustomOp. In the end state, we will likely have this explicit API and a
magic API (which is sugar on top of an explicit API), since different
parties of users prefer different ones.
Concretely, to define a backward formula for a CustomOp:
- a user must provide us a "save for backward" function that accepts
(inputs, output) and returns exactly what they want saved for backward
- a user must provide us a "backward" function that accepts
(ctx, saved, *grads) and returns us the grad_inputs. The grad_inputs
are returned as a dict mapping str to a gradient.
Please see the changes in custom_op_db.py for examples of the API.
There are a number of pieces to this PR and I'm happy to split it if it
helps. They are:
- The actual APIs for specifying the two functions
(impl_save_for_backward, impl_backward)
- The autograd kernel: we take the functions the user give us and
construct an autograd.Function object that we then register to
the Autograd dispatch key
- Indirection for the autograd kernel. We add a layer of indirection so
that one can swap out the autograd kernel. This is necessary because by
default, we register an "autograd not implemented" kernel as the
Autograd implementation but then swap it for the actual kernel when the
user provides it.
Test Plan:
- We apply this API to give backward formulas for things in
custom_op_db. We then hook up custom_op_db to the Autograd OpInfo tests.
- Various tests in test_python_dispatch.py to check error cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101824
Approved by: https://github.com/ezyang