pytorch
e0eaa95e - [DCP] Remove _shard_tensor() call in load_sharded_optimizer_state_dict in optimizer.py (#111096)

`_shard_tensor()` calls into `dist.all_gather_object()`, which makes optimizer state dict loading extremely slow.

Workaround: call `FSDP._shard_utils._create_chunk_sharded_tensor()` to construct the ShardedTensor without any communication.

Thanks to @fegin for suggesting the fix, and to @mvpatel2000 for reporting the issue and providing the profiling details that helped us isolate the problematic source code quickly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111096
Approved by: https://github.com/fegin
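The key idea behind the workaround is that each rank can derive its own shard boundaries purely from global metadata (tensor size, rank, world size), so no collective communication like `all_gather_object` is needed. A minimal pure-Python sketch of that principle, with a hypothetical helper name and ceil-division chunking in the style of `torch.chunk` (not the actual FSDP internals):

```python
def local_chunk_bounds(global_size: int, rank: int, world_size: int):
    """Hypothetical helper: compute this rank's [start, end) shard bounds
    locally, with no communication, using ceil-division chunk sizes
    (the same convention torch.chunk uses)."""
    chunk = -(-global_size // world_size)  # ceil(global_size / world_size)
    start = min(rank * chunk, global_size)
    end = min(start + chunk, global_size)
    return start, end

# Every rank runs this independently and arrives at consistent,
# non-overlapping shards covering the full tensor.
bounds = [local_chunk_bounds(10, r, 4) for r in range(4)]
```

Because the result depends only on metadata each rank already has, constructing the sharded view this way replaces an O(world_size) collective with purely local arithmetic.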