pytorch
24469020 - [FSDP] Use _init_from_local_tensor to create ShardedTensor to avoid communication overhead (#82911)

FSDP originally uses `_init_from_local_shards_and_global_metadata()` to create a ShardedTensor for `sharded_state_dict()`, which incurs non-trivial overhead when the number of tensors is large. Using `_init_from_local_tensor()` significantly reduces this overhead: for a model with ~250 tensors in the state_dict trained on 16 GPUs, the original `sharded_state_dict` takes ~1.7 seconds, and this PR reduces that to ~0.6 seconds.

Differential Revision: [D38452170](https://our.internmc.facebook.com/intern/diff/D38452170/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82911

Approved by: https://github.com/awgu
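For context, the faster construction path looks roughly like the sketch below. This is a minimal illustration, not FSDP's actual code: it assumes an initialized process group, relies on the private ShardedTensor APIs of that era (`ShardedTensor._init_from_local_tensor`, `ChunkShardingSpec`), and the helper name `shard_to_sharded_tensor` is hypothetical.

```python
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


def shard_to_sharded_tensor(local_shard: torch.Tensor, global_numel: int) -> ShardedTensor:
    """Hypothetical helper: wrap this rank's flat shard into a ShardedTensor.

    Must be called from every rank of an initialized process group.
    """
    world_size = dist.get_world_size()
    # FSDP shards a flattened parameter evenly along dim 0, so the global
    # sharding layout is known on every rank without any communication.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(world_size)],
    )
    # _init_from_local_tensor derives the global metadata locally from the
    # sharding spec, so no metadata needs to be exchanged across ranks
    # (unlike the public init_from_local_shards(), which gathers shard
    # metadata from all ranks).
    return ShardedTensor._init_from_local_tensor(
        local_shard.contiguous(),
        spec,
        global_numel,
        process_group=dist.group.WORLD,
    )
```

Because the metadata is computed locally from the spec, the per-tensor cost stays flat as the number of state_dict entries grows, which is consistent with the ~1.7 s to ~0.6 s improvement reported above.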