Make detach redispatch like a regular PyTorch operator (#71707)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71707
Why?
- detach should behave like jax.stop_gradient in functorch. Because
detach does not currently dispatch all the way through, functorch (as
well as a Tensor subclass wrapping another Tensor subclass) won't see it
after the first layer/subclass handles it (see the sketch below).
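A minimal sketch of the intended stop_gradient-like behavior under
functorch's grad transform (this uses the functorch-era `from functorch
import grad` import; newer releases expose it as torch.func.grad):
```python
import torch
from functorch import grad  # torch.func.grad in newer PyTorch

def f(x):
    # With detach behaving like jax.stop_gradient, gradients do not flow
    # through the detached term, so only the `+ x` term contributes.
    return (x * x).detach() + x

# Expected gradient: 1.0, since the x*x term is cut off by detach.
print(grad(f)(torch.tensor(3.0)))
```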
How?
- This PR changes detach to dispatch all the way through to the backend
(see the sketch after this list).
- This PR also modifies native::detach to call shallow_copy_and_detach
instead of native::alias. This is because today the semantics of detach
and alias are different -- they differ only in
allow_tensor_metadata_change. In the future, we may choose to deprecate
this flag.
- NB: Before and after this PR, detach() shows up twice in
torch_dispatch: https://github.com/pytorch/pytorch/issues/71725. This is
not a regression, so I did not fix it in this PR; the fix is awkward.
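As a rough illustration (not the code added in this PR), here is a
sketch of how detach now reaches __torch_dispatch__, assuming the
TorchDispatchMode utility from torch.utils._python_dispatch:
```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    # Print every ATen op that reaches __torch_dispatch__.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

x = torch.randn(3, requires_grad=True)
with LogOps():
    # With detach redispatching like a regular operator, aten.detach is
    # logged here (it may appear twice; see the issue linked above).
    y = x.detach()
```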
Test Plan: - added new tests; ran existing tests
Reviewed By: albanD
Differential Revision: D33752860
Pulled By: zou3519
fbshipit-source-id: 40cc2dc8232e75a02586a4ba5b0ef5f16cb76617
(cherry picked from commit f88aae426ec00bba907e9ad5d1cd6ed2c40bf14a)