pytorch
5ea418bf - [FSDP][3/N] Move `fsdp_modules(root_only=True)` -> `_get_fsdp_root_states()` (#90862)

Commit
2 years ago
[FSDP][3/N] Move `fsdp_modules(root_only=True)` -> `_get_fsdp_root_states()` (#90862) - This PR introduces `_get_fsdp_root_states(state: _FSDPState, module: nn.Module)` to return all states that are FSDP root in the module tree rooted at `module`. - This requires passing in both `state` and `module` because it must call `_lazy_init()` to check for root-ness, which requires that signature. - This PR moves the one internal usage of `FullyShardedDataParallel.fsdp_modules(root_only=True)` to use `_get_fsdp_root_states()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90862 Approved by: https://github.com/rohan-varma
Author
Committer
Parents
Loading