pytorch
6c30dc6c - [FSDP] Save `_all_handles`; `_all_fsdp_states` to root (#95465)

Commit
2 years ago
[FSDP] Save `_all_handles`; `_all_fsdp_states` to root (#95465) - The previous PR addressed one tree traversal in `_root_pre_forward()` but not the main one from `_get_fsdp_handles()` that runs for all settings. - This PR saves `_all_handles` to cache `_get_fsdp_handles()` and `_all_fsdp_states` to cache `_get_fsdp_states()` (renamed from `_fsdp_states` compared to last PR) on the root state. - This PR introduces a dummy `_RootFSDPState` class that inherits from `_FSDPState` to be used only for type checking since some attributes are only defined for root states. - I found this approach to be better than adding `_p_assert(state.root_only_attr is not None, ...)` upon each usage of `root_only_attr`. - This hopefully also helps readers to quickly see which attributes are defined only on root states. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95465 Approved by: https://github.com/fduwjj
Author
Andrew Gu
Committer
Parents
Loading