[DCP] Remove _shard_tensor() call in load_sharded_optimizer_state_dict in optimizer.py (#111096)
`_shard_tensor()` calls into `dist.all_gather_object()`, which makes loading the optimizer state dict very slow. Workaround: call `FSDP._shard_utils._create_chunk_sharded_tensor()` instead, which constructs the ShardedTensor without any communication.
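For intuition on why a chunk-based construction can skip communication: if the shard layout is a deterministic function of the tensor size, world size, and rank, every rank can compute its own bounds and those of its peers locally. The following is only a minimal sketch under that assumption (ceil-division chunking along dim 0); `chunk_bounds` and `all_chunk_bounds` are hypothetical helper names for illustration, not the actual `_create_chunk_sharded_tensor()` implementation.

```python
from typing import List, Tuple


def chunk_bounds(dim_size: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return (offset, length) of `rank`'s chunk along dim 0.

    Assumes ceil-division chunking: each chunk holds
    ceil(dim_size / world_size) rows, and trailing ranks may get a
    shorter or empty chunk.
    """
    # Integer ceil-division avoids floating-point rounding.
    chunk_size = (dim_size + world_size - 1) // world_size
    # Clamp so ranks past the end of the tensor get an empty chunk.
    offset = min(rank * chunk_size, dim_size)
    length = min(chunk_size, dim_size - offset)
    return offset, length


def all_chunk_bounds(dim_size: int, world_size: int) -> List[Tuple[int, int]]:
    """Bounds for every rank, computed locally with no collective call."""
    return [chunk_bounds(dim_size, world_size, r) for r in range(world_size)]
```

Because every rank evaluates the same pure function, the full shard metadata is available locally and no `all_gather_object()`-style exchange is required to agree on it.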
Thanks to @fegin for suggesting the fix!
Thanks to @mvpatel2000 for reporting the issue and providing the profiling details that helped us quickly isolate the problematic code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111096
Approved by: https://github.com/fegin