pytorch
1f3aa535 - [reland] use scatter in shard_tensor API (#75991)

Commit
2 years ago
[reland] use scatter in shard_tensor API (#75991) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75991 This is a reland of PR https://github.com/pytorch/pytorch/pull/72160. The previous PR failed on some cases where un-even scatter happens. So this PR made some additional fixes to ensure it scatters correctly. 1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to https://github.com/pytorch/pytorch/pull/75535 2. resize the shard to the same size before calling `dist.scatter`, and resize it back to the original layout after receiving from scatter. ghstack-source-id: 154725305 Test Plan: test_sharded_tensor test_linear test_megatron_prototype test_embedding/embeddingbag Reviewed By: pritamdamania87 Differential Revision: D35726920 fbshipit-source-id: d9bd0e44f47ef5b9e9add0dc66c5fda99e93943a (cherry picked from commit ed11e5d9a9e406e42f2a989aeaaf38ed9fb6b4b6)
Author
Committer
Parents
Loading