[reland] use scatter in shard_tensor API (#75991)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75991
This is a reland of PR https://github.com/pytorch/pytorch/pull/72160. The previous PR failed on some cases where un-even scatter happens. So this PR made some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to https://github.com/pytorch/pytorch/pull/75535
2. resize the shard to the same size before calling `dist.scatter`, and resize it back to the original layout after receiving from scatter.
ghstack-source-id: 154725305
Test Plan:
test_sharded_tensor
test_linear
test_megatron_prototype
test_embedding/embeddingbag
Reviewed By: pritamdamania87
Differential Revision: D35726920
fbshipit-source-id: d9bd0e44f47ef5b9e9add0dc66c5fda99e93943a
(cherry picked from commit ed11e5d9a9e406e42f2a989aeaaf38ed9fb6b4b6)