[FSDP] Implement sharded_state_dict and load_sharded_state_dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77356
Implement ShardedTensor-compatible sharded_state_dict() and load_sharded_state_dict().
Algorithm overview:
sharded_state_dict():
1. Call summon_full_params().
2. For each unflattened, non-sharded parameter:
2.1 Call chunk() to get the local shard of the parameter.
2.2 Create a ShardedTensor.
3. Replace the tensor in the state_dict with the newly created ShardedTensor.
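A minimal sketch of steps 2-3, assuming an initialized process group and parameters that divide evenly along dim 0. The names `_to_sharded_tensor`, `sharded_state_dict_sketch`, and `full_state_dict` are illustrative, not the actual FSDP internals:
```python
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    init_from_local_shards,
)

def _to_sharded_tensor(tensor: torch.Tensor) -> ShardedTensor:
    """Chunk an unflattened, non-sharded parameter along dim 0 and wrap
    this rank's chunk in a ShardedTensor (steps 2.1-2.2)."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Step 2.1: chunk() yields this rank's local shard of the parameter.
    # (Sketch assumes tensor.size(0) >= world_size; the real code also
    # handles uneven and padded cases.)
    chunks = torch.chunk(tensor, world_size, dim=0)
    local_chunk = chunks[rank]
    offset = sum(c.size(0) for c in chunks[:rank])
    local_shard = Shard.from_tensor_and_offsets(
        local_chunk, [offset] + [0] * (tensor.dim() - 1), rank
    )
    # Step 2.2: assemble a ShardedTensor from the local shard.
    return init_from_local_shards([local_shard], *tensor.size())

def sharded_state_dict_sketch(full_state_dict):
    # Step 3: replace each tensor with its ShardedTensor counterpart.
    return {k: _to_sharded_tensor(v) for k, v in full_state_dict.items()}
```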
load_sharded_state_dict():
1. For each unflattened, sharded parameter (ShardedTensor) in the given state_dict:
1.1 Pop it out of the state_dict.
1.2 Perform an all-gather to reconstruct the unflattened, non-sharded parameter.
2. Create a FlatParameter with the unflattened, non-sharded parameters.
3. Shard the newly created FlatParameter.
4. Insert the new FlatParameter into the state_dict.
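A minimal sketch of the load path, under the same even-sharding assumption as above (dist.all_gather requires equal shard sizes). The names `_unshard`, `load_sharded_state_dict_sketch`, and the `"flat_param"` key are illustrative stand-ins for FSDP's FlatParameter machinery:
```python
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor

def _unshard(st: ShardedTensor) -> torch.Tensor:
    """Step 1.2: all-gather the local shards to reconstruct the
    unflattened, non-sharded parameter."""
    local = st.local_shards()[0].tensor
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    return torch.cat(gathered, dim=0)

def load_sharded_state_dict_sketch(state_dict):
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    full_tensors = []
    for key in list(state_dict.keys()):
        # Step 1.1: pop each ShardedTensor out of the state_dict.
        st = state_dict.pop(key)
        full_tensors.append(_unshard(st))
    # Step 2: build a FlatParameter-like 1-D tensor from the full params.
    flat_param = torch.cat([t.reshape(-1) for t in full_tensors])
    # Step 3: shard the flat parameter; this rank keeps only its chunk.
    local_shard = torch.chunk(flat_param, world_size)[rank]
    # Step 4: insert the new flat parameter back into the state_dict.
    # ("flat_param" is an illustrative key, not FSDP's internal name.)
    state_dict["flat_param"] = local_shard
    return state_dict
```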
Differential Revision: [D36284983](https://our.internmc.facebook.com/intern/diff/D36284983/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36284983/)!
Approved by: https://github.com/zhaojuanmao