[FSDP] Expose internal prefetch limits (#86198)
This PR refactors the prefetching implementation to enable a module to prefetch more than one all-gather.
- The motivation is for backward prefetching, but forward prefetching is included in the change as well.
- The prefetching limit is a _limit_. In some edge cases (e.g. dynamic graph or first/last module), the limit may not be reached.
- The prefetching limit is kept as internal in this PR -- it is set as local variables `backward_prefetch_limit` and `forward_prefetch_limit` in the `FullyShardedDataParallel` constructor and passed to the `_ExecOrderData()` constructor.
- This PR additionally includes some clean up for forward prefetching but does not change any semantics assuming static graph.
If we increase the `backward_prefetch_limit` to `2`, then a typical pattern may be that the first module in the pre-backward prefetches 2, but every next module only prefetches 1 since its first target was already prefetched by the previous. If we did not do this behavior, then with more modules, the prefetching would run further and further ahead.
**`_handles_prefetched`**
- This is used to avoid multiple modules prefetching the same handles keys.
- `_handles_prefetched[handles_key]` is set to `True` when the prefetch for `handles_key` happens from the CPU thread (`_prefetch_handles()`).
- `_handles_prefetched[handles_key]` is set to `False` when any handle in `handles_key` is resharded (`_reshard()`).
- `_handles_prefetched` is cleared at the end of the backward (`_wait_for_post_backward()`).
**`_needs_pre_backward_unshard`**
- This is used to determine if a handles key should be backward prefetched at all.
- `_needs_pre_backward_unshard[handles_key]` is set to `False` in the post-forward (`_register_pre_backward_hooks()`).
- `_needs_pre_backward_unshard[handles_key]` is set to `True` in the post-forward if the forward outputs include tensors that require gradient (`_register_pre_backward_hook()`).
- `_needs_pre_backward_unshard[handles_key]` is set to `False` in the pre-backward hook, after unsharding (`_pre_backward_hook()`).
**`_needs_pre_forward_unshard`**
- This is used to determine if a handles key should be forward prefetched at all.
- `_needs_pre_forward_unshard[handles_key]` is set to `True` in the root's pre-forward (`_fsdp_root_pre_forward()`).
- `_needs_pre_forward_unshard[handles_key]` is set to `False` in the pre-forward unshard (`_pre_forward_unshard()`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86198
Approved by: https://github.com/zhaojuanmao