[FSDP2] Added backward prefetching (#118118)
This PR adds explicit backward prefetching to overlap communication and computation in backward (needed, in particular, for `reshard_after_forward=True` or `reshard_after_forward: int`). We do this by recording the post-forward order and using its reverse to approximate the backward order.
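As a rough illustration, here is a minimal sketch of the heuristic; the `ParamGroup` class and hook names are hypothetical stand-ins, not the actual FSDP2 internals:

```python
class ParamGroup:
    """Hypothetical stand-in for an FSDP2 parameter group."""
    def __init__(self, name):
        self.name = name

    def unshard(self, async_op=False):
        # Stand-in for the all-gather that materializes full parameters.
        print(f"all-gather {self.name} (async={async_op})")

post_forward_order = []  # recorded as each module's forward finishes

def post_forward_hook(group):
    # Remember forward completion order; its reverse approximates
    # the backward order.
    post_forward_order.append(group)

def pre_backward_hook(group):
    # Unshard the current module, then prefetch the module expected to
    # run next in backward (the previous one in forward order) so its
    # all-gather overlaps with the current module's backward compute.
    group.unshard()
    idx = post_forward_order.index(group)
    if idx > 0:
        post_forward_order[idx - 1].unshard(async_op=True)  # prefetch
```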
This works for typical 1-forward/1-backward training. However, for more complex schedules, this approach has some gaps:
- We need to know the _true end of backward_.
  - At the true end of backward, we can clear the recorded post-forward order and pre-backward hook state, and we should wait on gradient reductions.
  - There is no easy way to know whether the current backward marks the true end of backward. Therefore, we introduce an API for the user to set this: `fsdp_module.set_is_last_backward(bool)`. For example, for pipeline parallelism's DFS cooldown backward, we can call `fsdp_module.set_is_last_backward(is_last_microbatch)` (see the usage sketch after this list).
- When the user runs backward through only part of the model, our reverse-post-forward-order heuristic risks _mistargeted prefetches_ for modules unused in that backward, whose parameters would then be all-gathered and not freed until the end of backward.
  - To err on the side of less memory usage (but no overlap), this PR introduces logic to check whether a module will need its unshard in the current backward (by recording the module's `forward` outputs' `grad_fn`s and querying the autograd engine; see the second sketch after this list).
  - Note that there may be _no_ overlap in backward for some parts of the model due to no prefetching.
  - Note further that when running multiple backwards, if the user does not use `set_is_last_backward`, we may not be able to provide a meaningful error message, as the pre-backward hook state could be erroneously cleared on the first backward.
- In the future, we may expose more APIs from the autograd engine (similar to `_current_graph_task_execution_order`) to make the prefetching exact. (Currently, `_current_graph_task_execution_order` requires running under `with torch.autograd.set_multithreading_enabled(False)`, which is too strict a constraint since we cannot easily modify users' training loops; we could replace the multi-threading check with a device check. Moreover, in the partial-backward case in this PR's unit test, I still hit an [internal assertion](https://github.com/pytorch/pytorch/blob/b816760a2f27adafb0b1dac4c032a2e97c690b29/torch/csrc/autograd/engine.cpp#L476), so some follow-up is required.)
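A usage sketch for `set_is_last_backward` with multiple backwards per optimizer step; `model` is assumed to be an FSDP2-wrapped module, and `microbatches`, `loss_fn`, and `optim` are placeholders:

```python
# Hedged sketch: multiple backwards per step (e.g. microbatch gradient
# accumulation or a pipeline cooldown). Only the backward flagged as
# last clears the recorded post-forward order and pre-backward hook
# state and waits on gradient reductions.
for i, microbatch in enumerate(microbatches):
    is_last_microbatch = i == len(microbatches) - 1
    model.set_is_last_backward(is_last_microbatch)
    loss = loss_fn(model(microbatch))
    loss.backward()
optim.step()
optim.zero_grad()
```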
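And a minimal sketch of the unshard check; `torch._C._will_engine_execute_node` here is one way to query the engine for illustration (a private API that may only be called while backward is running), simplified relative to the actual logic:

```python
import torch

recorded_grad_fns = []  # grad_fns of a module's forward outputs

def record_output_grad_fns(outputs):
    # Called on the module's forward outputs: remember their grad_fns
    # so we can later ask the engine whether they will run.
    for out in outputs:
        if isinstance(out, torch.Tensor) and out.grad_fn is not None:
            recorded_grad_fns.append(out.grad_fn)

def needs_unshard_in_current_backward() -> bool:
    # Must be called while backward is executing; returns True if the
    # current graph task will reach any of the recorded grad_fns, i.e.
    # the module's parameters will be needed in this backward.
    return any(
        torch._C._will_engine_execute_node(fn) for fn in recorded_grad_fns
    )
```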
<details>
<summary> Old Discussion </summary>
For discussion:
- The PR includes a counter `expected_backward_unshard_count` to mitigate mistargeted prefetches in backward. However, it is a necessary but not a sufficient solution:
  - If a module's outputs do not require gradient, then we certainly do not need to unshard the module in backward.
  - However, if a module's outputs do require gradient, then we still may not need to unshard the module for _this_ backward (e.g. if the module did not contribute to `loss` for the current `loss.backward()`).
  - This counter only addresses the first case, not the second. If we want to address the second, then we may need more info from the autograd engine.
- For now, I did not include any unit test to cover these behaviors, as I do not have a good example yet.
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118118
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
ghstack dependencies: #118017