add gather to ShardedTensor (#65671)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65671
Tentative implementation that uses dist.gather_object to collect shards from all ranks and then "merge" them on dst_rank. The merge is done by zero-padding each sharded tensor to the size of the full tensor based on its metadata (offsets, lengths), and then summing the padded tensors together (see the sketch below).
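A minimal sketch of the pad-and-sum merge, assuming a hypothetical shard format of (tensor, offsets) pairs per rank; the function and variable names here are illustrative, not the actual ShardedTensor API:

```python
import torch
import torch.distributed as dist

def gather_sharded(local_shards, full_size, dst=0):
    # local_shards: list of (tensor, offsets) for this rank (hypothetical format).
    # Collect every rank's shards (plus metadata) on dst via gather_object.
    gathered = [None] * dist.get_world_size() if dist.get_rank() == dst else None
    dist.gather_object(local_shards, gathered, dst=dst)
    if dist.get_rank() != dst:
        return None
    # On dst: zero-pad each shard into a full-size tensor at its offsets,
    # then sum; shards are non-overlapping, so the sum reconstructs the tensor.
    full = torch.zeros(full_size)
    for shards in gathered:
        for tensor, offsets in shards:
            padded = torch.zeros(full_size)
            slices = tuple(slice(o, o + s) for o, s in zip(offsets, tensor.size()))
            padded[slices] = tensor
            full += padded
    return full
```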
Also considered concatenating the sharded tensors without padding to reduce the memory footprint (assuming padding increases memory), but that may not be flexible enough for arbitrary sharding (e.g. sharding along multiple dimensions).
Another option is to construct the padded tensor on each rank and reduce it to rank 0 (sketched below). I feel this is the easiest implementation, but it incurs higher memory usage and communication payload. Please let me know if this alternative is preferred.
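For comparison, a sketch of that reduce-based alternative under the same assumed shard format: every rank pads its own shards into a full-size tensor and a single dist.reduce(SUM) reconstructs it on dst, at the cost of a full-size allocation and send per rank.

```python
def gather_by_reduce(local_shards, full_size, dst=0):
    # Each rank builds a full-size tensor containing only its own shards.
    padded = torch.zeros(full_size)
    for tensor, offsets in local_shards:
        slices = tuple(slice(o, o + s) for o, s in zip(offsets, tensor.size()))
        padded[slices] = tensor
    # Summing across ranks on dst reconstructs the full tensor.
    dist.reduce(padded, dst=dst, op=dist.ReduceOp.SUM)
    return padded if dist.get_rank() == dst else None
```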
cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang gcramer23
Test Plan:
Imported from OSS
python test/distributed/_sharded_tensor/test_sharded_tensor.py -v -k test_gather
Did not manage to test on OSS, but tested in fbcode by reserving an on-demand GPU:
arc patch D31197611
Modified the test to run with 2 GPUs, as the on-demand GPU machine only has 2 devices (D31227986).
buck test -c fbcode.enable_gpu_sections=true mode/dev-nosan caffe2/test/distributed/_sharded_tensor:sharded_tensor -- test_gather
buck-out/gen/caffe2/test/distributed/_sharded_tensor/sharded_tensor#binary.par test_sharded_tensor.TestShardedTensorChunked.test_gather
{F667213605}
Reviewed By: dagitses, pritamdamania87
Differential Revision: D31197611
Pulled By: dracifer
fbshipit-source-id: cf98b4a2d7838b11b9582eb23f826bb0fa38a7f4