pytorch
71a67d0c - [sharded_tensor] simplify init_from_local_shards API (#64481)

Commit
3 years ago
[sharded_tensor] simplify init_from_local_shards API (#64481) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64481 This simplifies `init_from_local_shards` API in sharded tensor, to only require user pass in a list of `Shard` and `overall_size`, instead of ShardedTensorMetadata. We will do the all_gather inside to form a valid ShardedTensorMetadata instead. TODO: add more test cases to improve coverage. ghstack-source-id: 141742350 Test Plan: TestShardedTensorFromLocalShards Reviewed By: pritamdamania87 Differential Revision: D30748504 fbshipit-source-id: 6e97d95ffafde6b5f3970e2c2ba33b76cabd8d8a
Author
Parents
Loading