pytorch
d52f121d - [Composable API]Common _State parent class for composable and wrapper FSDP (#89147)

Commit
2 years ago
[Composable API]Common _State parent class for composable and wrapper FSDP (#89147) **Why this PR?** For the composable APIs implementation, sometimes the internal APIs may not have the application (FSDP, DDP) root module but only the local module. One example is the state_dict/optimizer_state_dict implementation of FSDP. These APIs are designed to start with the root module of the model. It is tricky for these APIs to tell whether a random submodule is managed by either DDP or FSDP. It will be useful to have APIs like: `_get_module_state(module)`: return the composable state if this module is managed by composable API. `_get_module_fsdp_state(module)`: return the FSDP state if this module is managed by FSDP. **What does this PR propose?** 1. Make `_State` out of `_composable` module so that `FullyShardedDataParallel` can inherit from it. 2. A global `_module_state_mapping: Dict[nn.Module, _State]` that keeps the mapping of all submodules (not just root module) to the state. 3. Create `_get_module_state(module)` to look up `_module_state_mapping`. 4. Create `_get_module_fsdp_state(module)` that uses `_get_module_state(module)` to get the state then verifies if the state is `_FSDPState`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89147 Approved by: https://github.com/awgu
Author
Committer
Parents
Loading