[FSDP][2/N] `_summon_full_params` -> `_unshard_params` (#92297)

**Overview**
This PR stack adds support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step with some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
  - `summon_full_params()` core logic to `_unshard_params()`
  - `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
  - the previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state

The public API is unchanged (see the usage sketch after this message).

**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration (see the second sketch after this message).
  - We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths, because for `True` the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
  - Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, I prefer to keep the separation of FSDP state and module explicit for now for clarity.

**Follow-Ups**
- Using `writeback=True` together with `rank0_only=True` raises an error. The previous explanation was:
  > is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.

  I am not exactly sure what the different model parameter shapes refers to. However, I believe that we can support `writeback=True` with `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally` block, writing back, and then freeing (see the third sketch after this message). This should not increase peak memory, since rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not hold any other unsharded `FlatParameter`s in GPU memory.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92297
Approved by: https://github.com/rohan-varma
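
---

For context, a minimal usage sketch of the public API this refactor sits behind. It assumes a process group launched via `torchrun` with NCCL and at least one GPU; the model is purely illustrative.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes launch via `torchrun` so that the env vars for init_process_group exist.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(nn.Linear(8, 8).cuda())

# The public contract is unchanged by this PR: inside the context, each rank
# sees the full (unsharded) parameters. Internally, the call now routes through
# `_unshard_params()` -> `_unshard_params_recurse()` -> `_unshard_fsdp_state_params()`.
with FSDP.summon_full_params(model, recurse=True, writeback=False):
    for name, param in model.named_parameters():
        print(name, tuple(param.shape))  # full shapes, not sharded shapes

dist.destroy_process_group()
```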
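Second, a conceptual sketch (not the PR's actual code) of the shape of `_get_fsdp_states_with_modules()`: parallel lists of FSDP states and the modules that own them, so callers know where the `FlatParameter` is registered. It assumes the internal helpers `_FSDPState` and `_get_module_fsdp_state()` from `torch.distributed.fsdp._common_utils`.

```python
from typing import List, Tuple

import torch.nn as nn
from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state


def _get_fsdp_states_with_modules_sketch(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """Conceptual sketch: collect each FSDP state in the module tree exactly
    once, paired with the module that owns it (needed for `FlatParameter`
    (de)registration when `use_orig_params=False`)."""
    fsdp_states: List[_FSDPState] = []
    fsdp_modules: List[nn.Module] = []
    visited = set()
    for submodule in module.modules():
        state = _get_module_fsdp_state(submodule)  # per-module state lookup
        if state is not None and id(state) not in visited:
            visited.add(id(state))
            fsdp_states.append(state)
            fsdp_modules.append(submodule)
    return fsdp_states, fsdp_modules
```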
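Third, a hedged sketch of the proposed broadcast-then-writeback path for supporting `writeback=True` with `rank0_only=True`. The names `unsharded_flat_param`, `sharded_flat_param`, and `shard_start` are illustrative placeholders, not FSDP internals, and padding and free/registration bookkeeping are omitted.

```python
import torch
import torch.distributed as dist


def _writeback_from_rank0_sketch(
    unsharded_flat_param: torch.Tensor,
    sharded_flat_param: torch.Tensor,
    shard_start: int,
    process_group: dist.ProcessGroup,
) -> None:
    """Sketch of the follow-up idea: rank 0 already holds the (possibly
    user-modified) unsharded `FlatParameter`; broadcast it so every rank can
    copy its own slice back into its local shard, then the caller frees the
    unsharded tensor. Peak memory stays bounded by rank 0's existing unsharded
    copy, since nonzero ranks hold no other unsharded `FlatParameter`s here."""
    dist.broadcast(unsharded_flat_param, src=0, group=process_group)
    shard_numel = sharded_flat_param.numel()
    with torch.no_grad():
        # Each rank writes back only its own contiguous slice of the
        # unsharded data into its local shard.
        sharded_flat_param.copy_(
            unsharded_flat_param[shard_start : shard_start + shard_numel]
        )
```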