pytorch
9f4e7dec - [FSDP] Add re-key btw param names/IDs for optim state dict (#74912)

Commit
2 years ago
[FSDP] Add re-key btw param names/IDs for optim state dict (#74912)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74912

**Overview**
This introduces a new static method `FSDP.rekey_optim_state_dict()` as a utility for interoperating between local/DDP (non-wrapped) models and FSDP (wrapped) models.

To load from a wrapped model to a non-wrapped model:
```
wrapped_model, wrapped_optim = ...
full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
nonwrapped_model, nonwrapped_optim = ...
rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
nonwrapped_optim.load_state_dict(rekeyed_osd)
```

To load from a non-wrapped model to a wrapped model:
```
nonwrapped_model, nonwrapped_optim = ...
osd = nonwrapped_optim.state_dict()
rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
wrapped_model, wrapped_optim = ...
sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
wrapped_optim.load_state_dict(sharded_osd)
```

**Test Plan**
`test_rekey_optim_state_dict_to_ids()` and `test_rekey_optim_state_dict_to_names()`.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D35225819

Pulled By: awgu

fbshipit-source-id: fbbdbde8b595a9c65b17a9aecb4f22b2c9761a23
(cherry picked from commit dba4ef949074afa8ae858d88a6c089ce0f6f1229)
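To make the re-keying idea concrete, here is a minimal pure-Python sketch of converting an optimizer state dict between integer parameter IDs and parameter names. It is not the FSDP implementation: the helpers `rekey_by_name` and `rekey_by_id` are hypothetical, and the sketch assumes that parameter ID `i` corresponds to the `i`-th entry of a name list (for example, one built from `model.named_parameters()` when the optimizer was constructed from `model.parameters()` in the same order).

```python
from typing import Any, Dict, List


def rekey_by_name(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of an ID-keyed optimizer state
    dict whose keys are parameter names instead of integer IDs.

    Assumes param ID i maps to param_names[i].
    """
    state = {param_names[pid]: dict(s) for pid, s in osd["state"].items()}
    groups = []
    for group in osd["param_groups"]:
        new_group = dict(group)  # copy hyperparameters such as lr
        new_group["params"] = [param_names[pid] for pid in group["params"]]
        groups.append(new_group)
    return {"state": state, "param_groups": groups}


def rekey_by_id(osd: Dict[str, Any], param_names: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: inverse of rekey_by_name (names back to IDs)."""
    name_to_id = {name: i for i, name in enumerate(param_names)}
    state = {name_to_id[name]: dict(s) for name, s in osd["state"].items()}
    groups = []
    for group in osd["param_groups"]:
        new_group = dict(group)
        new_group["params"] = [name_to_id[name] for name in group["params"]]
        groups.append(new_group)
    return {"state": state, "param_groups": groups}


# Toy state dict in the ID-keyed layout used by non-wrapped optimizers.
id_keyed = {
    "state": {0: {"step": 3}, 1: {"step": 3}},
    "param_groups": [{"lr": 0.01, "params": [0, 1]}],
}
names = ["linear.weight", "linear.bias"]

name_keyed = rekey_by_name(id_keyed, names)
round_trip = rekey_by_id(name_keyed, names)
```

Keying by name is what lets the state line up with a wrapped model, whose parameter ordering and flattening may differ from the non-wrapped model; the real method takes the model itself to derive this mapping.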