95ee5fec - [FSDP][1/N] Add `_get_fsdp_states()` (#90860)

- This PR introduces `_get_fsdp_states(module: nn.Module) -> List[_FSDPState]` to prepare for `fully_shard` manual "wrapping" (see the sketch below).
- ~~I placed it in `_runtime_utils.py`, not `_common_utils.py`, because a follow-up PR will add `_get_root_fsdp_states()`, which requires `_lazy_init()`. I concluded that it is preferable to keep both getters in the same place rather than split them, even if that means `_get_fsdp_states()` lives in `_runtime_utils.py`.~~ Due to circular import issues, I think I should still put it in `_common_utils.py`.
- This PR changes `FullyShardedDataParallel.fsdp_modules()` to be backed by `_get_fsdp_states()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90860
Approved by: https://github.com/rohan-varma
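For context, here is a minimal sketch of the kind of traversal a getter like `_get_fsdp_states()` performs: walk the module tree and collect each distinct FSDP state once. The `_get_module_fsdp_state` helper and the `_fsdp_state` attribute below are assumptions for illustration only; the actual state-lookup mechanism differs between wrapper FSDP and composable `fully_shard`.

```python
from typing import List, Optional, Set

import torch.nn as nn


class _FSDPState:
    """Placeholder standing in for the real FSDP per-instance state object."""
    ...


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    # Hypothetical lookup of a module's FSDP state; the real mechanism
    # is not a plain attribute read.
    return getattr(module, "_fsdp_state", None)


def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
    """Collect every distinct ``_FSDPState`` in the tree rooted at
    ``module``, in ``module.modules()`` (depth-first) traversal order."""
    fsdp_states: List[_FSDPState] = []
    visited: Set[int] = set()
    for submodule in module.modules():
        state = _get_module_fsdp_state(submodule)
        # With manual `fully_shard` wrapping, several submodules may share
        # one state, so deduplicate while preserving first-visit order.
        if state is not None and id(state) not in visited:
            visited.add(id(state))
            fsdp_states.append(state)
    return fsdp_states
```

Under this sketch, `FullyShardedDataParallel.fsdp_modules()` can be backed by the same traversal, mapping each collected state back to its owning module.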