Make make_dual redispatch (#68630)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68630
Constraints:
1) (functorch) if all the inputs to an op have requires_grad=False and don't have tangents, then its VariableType
kernel should be a no-op, i.e., behave like a redispatch. This is because functorch's DynamicLayerStack
has the autograd key by default (so that transformations like vmap still work with autograd).
2) (inference mode) inference tensors in inference mode call straight into the kernel, so we should still do something sensible
inside it even if we normally wouldn't redispatch into it (see the usage sketch after this list).
3) ~Should support potential application of interposition below autograd: `nn.Parameter` is an example of subclassing where the subclass
is not preserved when an operation is performed. There is an exception though: we want calling `make_dual` on an
`nn.Parameter` to preserve its parameter-ness.~
4) Should avoid calls to `shallow_copy_and_detach`, which lead to spurious calls into `__python_dispatch__`.
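To make the constraints concrete, here is a minimal usage sketch of the Python-level forward-AD API that `make_dual` backs. The storage-sharing assertion reflects this PR's alias-based implementation and is illustrative rather than a documented guarantee.

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)  # primal with requires_grad=False (the case constraint (1) cares about)
t = torch.randn(3)  # tangent

with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    # With this PR, make_dual behaves like a view op: the dual aliases the primal.
    assert dual.data_ptr() == x.data_ptr()

    out = dual.sin()
    # The forward grad (JVP) of sin at x is cos(x) * t.
    print(fwAD.unpack_dual(out).tangent)
```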
This PR:
- does not redispatch to `make_dual` from its `ADInplaceOrView` kernel to satisfy (1)
- calls into `alias` from the kernel in the native namespace so that behavior is consistent with other views in inference mode to satisfy (2)
- discussion of (3): we still wouldn't be able to directly override `make_dual` below autograd. In this PR, instead of not redispatching at all, we choose to redispatch into `at::alias` so that one can override `make_dual`. The side effect is that one would not be able to distinguish calls between the two, which can be problematic (a straightforward but hacky solution would be to create a new `at::alias_for_make_dual` that would allow users to distinguish the two). This isn't ideal, but it seems to be the simplest way to satisfy (3). We don't pursue that hacky solution here (see the sketch after this list).
- (4) is satisfied because we remove calls to `shallow_copy_and_detach`
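As a rough illustration of the interposition-below-autograd trade-off discussed in (3), the sketch below logs which below-autograd ops a `make_dual` call lowers to, using `TorchDispatchMode` (one way to hook in at the point the `__python_dispatch__` constraint refers to). Per this PR, one would expect to see an alias-style view op rather than a separately overridable `make_dual` entry; this is an assumption-laden sketch, not a test from the PR.

```python
import torch
import torch.autograd.forward_ad as fwAD
from torch.utils._python_dispatch import TorchDispatchMode

class OpLogger(TorchDispatchMode):
    """Log every op that reaches the below-autograd Python dispatch point."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("below-autograd op:", func)
        return func(*args, **(kwargs or {}))

x, t = torch.randn(3), torch.randn(3)
with fwAD.dual_level(), OpLogger():
    # Per the discussion above, the logged op is expected to look like a plain
    # alias/view call, indistinguishable from an explicit at::alias.
    dual = fwAD.make_dual(x, t)
```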
<details>
<summary> A potentially less hacky but more involved solution? (WIP) </summary>
Realizing that `make_dual` is more like `requires_grad`, perhaps it shouldn't be autograd-explicit? Make `make_dual` a composite or Python-only construct, i.e., a view on the primal followed by something to the effect of `primal.set_fw_grad(tangent)`.
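A very rough Python-level reading of that decomposition, with the forward-grad setter stubbed out (it has no public Python binding); everything here is a design sketch of the idea above, not this PR's implementation:

```python
import torch

def set_fw_grad(view: torch.Tensor, tangent: torch.Tensor, level: int) -> None:
    # Hypothetical stand-in for the internal forward-grad setter mentioned above;
    # the real logic lives in C++ autograd metadata, so this only marks the call site.
    raise NotImplementedError("internal autograd hook; sketched for illustration only")

def composite_make_dual(primal: torch.Tensor, tangent: torch.Tensor, level: int = 0) -> torch.Tensor:
    # A plain view of the primal. Per the plan below, this would really need the new
    # `detach`-like op (rather than view_as) so that set_fw_grad can tell the result
    # is not itself a forward-differentiable view.
    dual = primal.view_as(primal)
    # Attach the tangent as the forward grad at `level` (hypothetical helper above).
    set_fw_grad(dual, tangent, level)
    return dual
```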
Additional constraints:
5) make_dual needs to be backward-differentiable (I can't think of any applications yet because, technically, as a
higher-order function jvp takes only the tangent as input; "detach" is not applied to
the tangent, so one would still be able to propagate gradients through it; see the sketch after this list).
6) set_fw_grad needs to raise an error if there is a layout mismatch and the base is a forward-differentiable view
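To illustrate what (5) asks for, here is a small sketch of reverse-mode differentiating through a forward-mode (jvp-style) computation with respect to the tangent. The expected gradient follows from the constraint's own reasoning (no detach is applied to the tangent); treat it as an assumption about the intended behavior rather than a claim about this PR.

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)                       # primal
t = torch.randn(3, requires_grad=True)   # tangent we want backward grads for

with fwAD.dual_level():
    dual = fwAD.make_dual(x, t)
    # Tangent of sin(dual) is cos(x) * t, computed with ordinary (autograd-visible) ops.
    out_tangent = fwAD.unpack_dual(dual.sin()).tangent

# Backward-differentiate through the forward-mode computation w.r.t. the tangent.
out_tangent.sum().backward()
print(t.grad)  # expected: cos(x), if make_dual is backward-differentiable as (5) asks
```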
Possible plan
- (6) implies that a plain view would not suffice. We need a `detach`-like operation to ensure that set_fw_grad
knows the view is not forward differentiable.
- (5) implies that this (new) `detach` would need to be backward-differentiable (API TBD).
- (3) is no longer relevant because `make_dual` is no longer autograd-explicit, but perhaps this new `detach` should behave like the current one? There is a lot of logic to replicate for `detach`, so this may be hard.
- (1) is satisfied if we use the current `detach` logic, and (4) is trivially satisfied.
I'm not convinced that this is the right solution either; in the end, does (3) still work?
</details>
Test Plan: Imported from OSS
Reviewed By: jbschlosser
Differential Revision: D32899679
Pulled By: soulitzer
fbshipit-source-id: 98e13ae954e14e1e68dbd03eb5ab3300d5ed2c5e