[FSDP] Fix `use_orig_params=True` + AC (#87413)
Without this change, FSDP's post-backward hooks do not run when `use_orig_params=True` is combined with reentrant activation checkpointing.
**Explanation**
FSDP registers the original parameters as plain `Tensor`s in the forward pass so that autograd tracks their ops and gradients propagate correctly into the `FlatParameter`s. FSDP registers its post-backward hooks in the pre-forward.
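To make the first point concrete, here is a small self-contained sketch, purely illustrative rather than FSDP's actual code, of registering an original parameter as a plain `Tensor` view into a flat parameter so that autograd routes its gradient into the flat parameter:

```python
# Illustrative only (not FSDP's actual code): replace a module's nn.Parameter with
# a plain Tensor view into a "flat parameter" so autograd routes the gradient there.
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
flat_param = nn.Parameter(
    torch.cat([lin.weight.detach().flatten(), lin.bias.detach().flatten()])
)

weight_view = flat_param[:16].view(4, 4)  # plain Tensor view, not an nn.Parameter
delattr(lin, "weight")                    # remove the registered nn.Parameter
setattr(lin, "weight", weight_view)       # register the autograd-tracked view instead

out = lin(torch.randn(2, 4)).sum()
out.backward()
# The weight's gradient landed in the flat parameter rather than on a separate leaf.
print(flat_param.grad[:16].view(4, 4))
```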
For `use_orig_params=True`, FSDP replaces the plain `Tensor`s with the sharded `nn.Parameter`s in the post-forward when resharding. This differs from `use_orig_params=False`, which keeps the plain `Tensor`s registered as attributes but frees their data, so accessing them between forward and backward raises an error. Before this PR, for `use_orig_params=True`, FSDP simply restored the unsharded original parameter data in the pre-backward to enable correct gradient computation. However, that does not suffice for reentrant activation checkpointing (AC), where the recomputed forward runs after FSDP's pre-backward and its ops must also be tracked by autograd.
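For concreteness, here is a minimal sketch (not the PR's unit test) of the configuration that previously failed, assuming a standard multi-process launch such as `torchrun --nproc_per_node=2` and a CUDA/NCCL setup:

```python
# Minimal sketch of the failing configuration: reentrant AC inside FSDP with
# use_orig_params=True. Assumes `torchrun --nproc_per_node=2 repro.py` and CUDA.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Reentrant AC: the wrapped module's forward is recomputed during backward,
# i.e. after FSDP's pre-backward hook has already run.
inner = checkpoint_wrapper(nn.Linear(8, 8), checkpoint_impl=CheckpointImpl.REENTRANT)
model = FSDP(nn.Sequential(inner, nn.Linear(8, 8)).cuda(), use_orig_params=True)

loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()  # before this fix, FSDP's post-backward hooks did not fire here
```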
My initial solution was to simply have FSDP restore the original parameters as plain `Tensor`s again in the pre-backward so that they would be tracked by autograd exactly as in the normal forward. However, this does not seem to suffice in general: the `FlatParameter`'s `AccumulateGrad` object may change after the original pre-forward when performing a recomputed forward.
The new approach in this PR mirrors the `use_orig_params=False` behavior: preserve the plain `Tensor` variables across forward and backward. I achieve this by saving the variables explicitly in the forward and restoring them in the pre-backward. I clear them in the post-backward to avoid dangling references (though I do not think this is strictly necessary).
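The following is a hypothetical, heavily simplified sketch of that save/restore idea; the real logic lives inside FSDP's internal flat-parameter handling and uses different names.

```python
# Hypothetical, simplified sketch of the save/restore pattern described above.
# The actual implementation is in FSDP's internals; this only illustrates the idea.
from typing import Dict, List

import torch
import torch.nn as nn


class PlainTensorSaver:
    """Keeps the plain-Tensor parameter views registered on a module alive from
    forward to backward so that a recomputed forward (reentrant AC) reuses the
    same autograd-tracked tensors as the original forward."""

    def __init__(self, module: nn.Module, param_names: List[str]) -> None:
        self._module = module
        self._param_names = param_names
        self._saved: Dict[str, torch.Tensor] = {}

    def save_in_forward(self) -> None:
        # Stash the plain Tensor views currently set as module attributes.
        self._saved = {name: getattr(self._module, name) for name in self._param_names}

    def restore_in_pre_backward(self) -> None:
        # Re-register the saved views so the recomputed forward sees the same tensors.
        for name, tensor in self._saved.items():
            setattr(self._module, name, tensor)

    def clear_in_post_backward(self) -> None:
        # Drop references so the unsharded data is not kept alive longer than needed.
        self._saved.clear()
```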
An alternative approach I considered was using forward hooks. However, this does not change the order of operations across FSDP, checkpoint, and the wrapped module, so it does not work: as long as the wrapping order is FSDP(checkpoint(module)), registered hooks still run either before or after the checkpoint recomputation; we cannot insert logic to run inside the recomputation itself.
**Test Plan**
I augmented the existing reentrant checkpointing unit tests to also test `use_orig_params=True`. I also verified that the pycls model does not error (even with the new approach).
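As a rough, illustrative distillation of what the augmented tests exercise (not the actual unit tests in the PyTorch test suite), one can run the earlier repro under both settings and check that gradients are produced, assuming the process group was already initialized as in the sketch above:

```python
# Illustrative only: run the reentrant-AC setup under both use_orig_params settings
# and check that backward produced gradients (i.e. the post-backward hooks ran).
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def run_reentrant_ac(use_orig_params: bool) -> None:
    inner = checkpoint_wrapper(nn.Linear(8, 8), checkpoint_impl=CheckpointImpl.REENTRANT)
    model = FSDP(
        nn.Sequential(inner, nn.Linear(8, 8)).cuda(),
        use_orig_params=use_orig_params,
    )
    model(torch.randn(4, 8, device="cuda")).sum().backward()
    # If the post-backward hooks ran, (sharded) gradients exist after backward.
    assert any(p.grad is not None for p in model.parameters())


# Assumes dist.init_process_group(...) was already called (see the earlier sketch).
for flag in (False, True):
    run_reentrant_ac(flag)
```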
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87413
Approved by: https://github.com/rohan-varma