[functorch] Allow people to arbitrarily add dispatch keys between DynamicLayer{Front,Back} (pytorch/functorch#843)
Fixes https://github.com/pytorch/functorch/issues/842
The Diagnosis
=============
As Brian pointed out:
For jvp(sub, ...), the chain of dispatch should be:
```
DynamicLayerFrontMode -> at::sub autograd kernel -> DynamicLayerBackMode
```
Instead, what we're doing today is:
```
JVP dynamic layer -> at::sub autograd kernel -> at::sub ZeroTensor kernel
```
(the ZeroTensor kernel errors out because the inputs are
BatchedTensorImpl objects)
The Problem
=============
functorch's behavior for dispatch keys between DynamicLayerFrontMode and
DynamicLayerBackMode should be:
- upon entering a dynamic layer (aka Interpreter), we zero out all
dispatch keys* between FrontMode and BackMode
- then, the dynamic layer (aka Interpreter) re-enables some dispatch
keys. For example, the JVPInterpreter re-enables the Autograd keys.
- next, we do a dispatcher call, which will end up hitting one of the
Autograd keys (in the JVPInterpreter case).
The bug is that functorch has a hardcoded list of dispatch keys that it
zeros out (sketched below). This list does not include ZeroTensor,
because before https://github.com/pytorch/pytorch/pull/77132 the
ZeroTensor key was not between DynamicLayer{Front,Back}Mode.
*There are exceptions for Autocast and kVmapMode, described in the next section.
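To make the mechanics concrete, here is a minimal sketch of the per-layer
key bookkeeping. The helper is modeled on functorch's interpreter setup,
but the names (`setup_dispatch_key_tls`, `kHardcodedListOfKeys`) should be
read as illustrative assumptions, not the exact source:
```
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/LocalDispatchKeySet.h>

using c10::DispatchKeySet;

// Upon entering a dynamic layer: mask out `exclude` for the duration of
// the layer, then turn the interpreter's own `include` keys back on.
void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include) {
  auto local = c10::impl::tls_local_dispatch_key_set();
  local.excluded_ = local.excluded_ | exclude;
  local.included_ = local.included_ | include;
  c10::impl::_force_tls_local_dispatch_key_set(local);
}

// Conceptually, a JVP layer then does, on entry:
//   setup_dispatch_key_tls(
//       /*exclude=*/kHardcodedListOfKeys,            // the buggy part
//       /*include=*/c10::autograd_dispatch_keyset);  // re-enable autograd
// DispatchKey::ZeroTensor was missing from the hardcoded list, so dispatch
// could fall through to at::sub's ZeroTensor kernel while inside the layer.
```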
The Solution
============
Change functorch to programmatically zero out all keys between
DynamicLayerBackMode and DynamicLayerFrontMode, with the exception of
Autocast and VmapMode (see the sketch below).
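A minimal sketch of what "programmatically" means here, assuming c10's
`DispatchKeySet(FULL_AFTER, ...)` constructor and the FuncTorch* key names
of this era; `keys_between_dynlayer_modes` and the `k...Key` aliases are
illustrative:
```
#include <c10/core/DispatchKeySet.h>

using c10::DispatchKey;
using c10::DispatchKeySet;

// Illustrative aliases for the functorch sentinel keys.
constexpr auto kDynamicLayerFrontModeKey =
    DispatchKey::FuncTorchDynamicLayerFrontMode;
constexpr auto kDynamicLayerBackModeKey =
    DispatchKey::FuncTorchDynamicLayerBackMode;
constexpr auto kVmapModeKey = DispatchKey::FuncTorchVmapMode;

DispatchKeySet keys_between_dynlayer_modes() {
  // FULL_AFTER(k) is every key strictly above k in dispatch order, so the
  // difference is every key above BackMode, up to and including FrontMode.
  auto result =
      DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerBackModeKey) -
      DispatchKeySet(DispatchKeySet::FULL_AFTER, kDynamicLayerFrontModeKey);
  // Exception 1: leave the Autocast keys alone (see below).
  result = result - c10::autocast_dispatch_keyset;
  // Exception 2: kVmapMode must stay active so random operations can still
  // be banned inside transforms.
  result = result.remove(kVmapModeKey);
  return result;
}
```
The point of computing the set from the enum order is that it tracks key
movements automatically: anything that later lands in the span between the
two modes gets zeroed out without anyone updating a list.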
This means that if someone adds a dispatch key between
DynamicLayerBackMode and DynamicLayerFrontMode in the future, we will
(probably) handle it "correctly". The model for dispatch is:
- [functorch] -> [regular pytorch dispatcher]
- a key like ZeroTensor gets handled in the [regular pytorch
dispatcher] section (see the check after this list).
- functorch transforms get handled in the [functorch] section.
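As a concrete consequence, a hypothetical sanity check (building on
`keys_between_dynlayer_modes()` from the sketch above):
```
#include <c10/util/Exception.h>

// Hypothetical check: after pytorch/pytorch#77132, ZeroTensor sits between
// the two modes, so the programmatic set picks it up automatically.
void check_zerotensor_is_excluded_by_layers() {
  TORCH_INTERNAL_ASSERT(
      keys_between_dynlayer_modes().has(c10::DispatchKey::ZeroTensor));
}
```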
We do not change the autocast <-> functorch interaction in this PR
(i.e., functorch does not zero out the Autocast keys) because I'm not
sure what the correct behavior is.
We do not change how kVmapMode works because... it needs to stay active
to ban random operations in transforms further down the line :/
Test Plan
============
Wait for tests