dd07dab1 - [FSDP][optim_state_dict] Support rank0_only when use_orig_params is on (#99624)

This PR makes the `use_orig_params=True` case support rank0_only loading for the optimizer state_dict. The implementation differs from the `use_orig_params=False` one: the `use_orig_params=False` implementation first flattens the parameters on rank0 and then broadcasts the states, while this implementation broadcasts the states while doing the flattening. This implementation is slower because it broadcasts the original parameters instead of the flattened ones, but it is simpler. Since loading usually happens once per training run, the performance difference can be ignored. In the next PR, we will consolidate the two implementations in favor of the simpler one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99624
Approved by: https://github.com/wz337
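For context, below is a minimal sketch of how rank0_only optimizer-state saving and loading is typically driven from user code with `use_orig_params=True`. It assumes the FSDP state-dict APIs from roughly the PyTorch 2.0/2.1 era (`FSDP.set_state_dict_type`, `FullStateDictConfig`, `FullOptimStateDictConfig`, `FSDP.optim_state_dict`, `FSDP.optim_state_dict_to_load`); the argument order of `optim_state_dict_to_load` changed across releases, and `MyModel` plus the checkpoint path are illustrative, so treat this as a sketch rather than the PR's own test code:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig

# Assumes dist.init_process_group() has already run; MyModel is a placeholder.
model = FSDP(MyModel().cuda(), use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Request full (unsharded) model and optimizer state dicts that are
# materialized only on rank 0.
FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
    FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
)

# Save: only rank 0 holds the consolidated optimizer state.
osd = FSDP.optim_state_dict(model, optim)
if dist.get_rank() == 0:
    torch.save(osd, "optim.pt")

# Load: rank 0 reads the checkpoint; with rank0_only, the other ranks rely
# on FSDP broadcasting the states (the path this PR enables for
# use_orig_params=True). What non-rank0 ranks pass varies by release; here
# we assume None is accepted.
osd = torch.load("optim.pt") if dist.get_rank() == 0 else None
flattened = FSDP.optim_state_dict_to_load(model, optim, osd)
optim.load_state_dict(flattened)
```

The design trade-off described above shows up in this flow: with `use_orig_params=True`, the broadcast happens on the original (unflattened) parameters during flattening, which costs more bandwidth than broadcasting flattened shards but keeps the load path simple, and checkpoint loading runs rarely enough that the cost is negligible.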