58c9d521 - [FSDP] Implement sharded_state_dict and load_sharded_state_dict

[FSDP] Implement sharded_state_dict and load_sharded_state_dict

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77356

Implement ShardedTensor-compatible sharded_state_dict() and load_sharded_state_dict().

Algorithm overview:

sharded_state_dict():
1. Call summon_full_parameters().
2. For each unflattened, non-sharded parameter:
   2.1 Call chunk() to get the local shard of the parameter.
   2.2 Create a ShardedTensor.
3. Replace the tensor in the state_dict with the newly created ShardedTensor.

load_sharded_state_dict():
1. For each unflattened, sharded parameter (ShardedTensor) in the given state_dict:
   1.1 Pop it out of the state_dict.
   1.2 Do an all-gather to reconstruct the unflattened, non-sharded parameter.
2. Create a FlatParameter from the unflattened, non-sharded parameters.
3. Shard the newly created FlatParameter.
4. Insert the new FlatParameter into the state_dict.

Differential Revision: [D36284983](https://our.internmc.facebook.com/intern/diff/D36284983/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36284983/)!

Approved by: https://github.com/zhaojuanmao
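The following is a minimal, hypothetical sketch of the two per-parameter steps described above (chunking a full parameter into a local shard, and all-gathering shards back into a full parameter), not the actual FSDP implementation from this PR. It assumes a default process group is already initialized, that sharding is along dim 0, and that every rank holds an equally sized shard; in the real code the local shard is additionally wrapped in a ShardedTensor before being placed into the state_dict.

```python
# Hypothetical sketch only; function names and the dim-0 sharding layout are
# assumptions for illustration, not the FSDP internals added by this commit.
import torch
import torch.distributed as dist


def to_local_shard(full_param: torch.Tensor) -> torch.Tensor:
    """Step 2.1 of sharded_state_dict(): chunk the unflattened, non-sharded
    parameter along dim 0 and keep only this rank's piece."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    chunks = torch.chunk(full_param, world_size, dim=0)
    # Ranks beyond the number of chunks would hold an empty shard.
    if rank < len(chunks):
        return chunks[rank].clone()
    return full_param.new_empty(0)


def reconstruct_full_param(local_shard: torch.Tensor) -> torch.Tensor:
    """Step 1.2 of load_sharded_state_dict(): all-gather the per-rank shards
    and concatenate them back into the unflattened, non-sharded parameter.
    Assumes every rank holds a shard of the same shape."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_shard)
    return torch.cat(gathered, dim=0)
```

The reconstructed full parameters would then be flattened into a new FlatParameter and re-sharded, as described in steps 2 and 3 of load_sharded_state_dict() above.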