[Reland][Autograd/Checkpoint] Checkpoint implementation without reentrant autograd (#69508)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69508
Original Phabricator Diff: D32704467 (https://github.com/pytorch/pytorch/commit/e032dae32904c6e90353e90051785514d6c5a7d9)
Reland; the fix is to not test the traditional (reentrant) checkpoint when the input does not require grad, since that case is unsupported, as documented.
Original PR body:
Resubmission of https://github.com/pytorch/pytorch/pull/62964 with the
suggestions and tests discussed in
https://github.com/pytorch/pytorch/issues/65537.
Adds a `use_reentrant=False` flag to the `checkpoint` function. When
`use_reentrant=False` is specified, a checkpointing implementation that uses
SavedVariableHooks instead of re-entrant autograd is used. This makes it more
composable with things such as `autograd.grad`, as well as with DDP (thorough
distributed testing still needs to be added).
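As a minimal sketch of the intended usage (the module, shapes, and tensors below are illustrative, not from this PR's tests):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Checkpoint a small module without re-entrant autograd, then differentiate
# with autograd.grad, which the re-entrant implementation does not compose with.
lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

out = checkpoint(lin, x, use_reentrant=False)
# Works because the non-reentrant implementation records the graph via saved
# variable hooks instead of running a nested backward pass.
(grad_x,) = torch.autograd.grad(out.sum(), x)
```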
As discussed in https://github.com/pytorch/pytorch/issues/65537, the tests that we need to add are:
- [x] Gradient hooks are called once
- [x] works when inputs do not require grad but Tensors that require grad are captured (like the first layer in an nn); see the sketch after this list
- [x] works for functions with arbitrary input/output objects
- [x] distributed tests (next PR)
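A hedged sketch of the second case above, assuming the non-reentrant path from this PR (the module and shapes are illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

# The input does not require grad (like the input to the first layer of a
# network), but the Linear parameters captured inside the checkpointed
# function do, so backward should still populate their .grad fields.
lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)  # requires_grad=False

out = checkpoint(lin, x, use_reentrant=False)
out.sum().backward()
assert lin.weight.grad is not None
```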
Note that this is only for `torch.utils.checkpoint`; if this approach looks good overall, we will do something similar for `checkpoint_sequential`.
ghstack-source-id: 144948501
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D32902634
fbshipit-source-id: 2ee87006e5045e5471ff80c36a07fbecc2bea3fe