[shard] use scatter in shard_parameter API (#72160)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72160
This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to every rank, we split the tensor on the source rank beforehand and send each rank only the shard it needs, which greatly reduces communication overhead.
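To illustrate the difference, here is a minimal sketch (not the actual `shard_parameter` implementation) of the scatter-based approach, assuming a process group is already initialized, the parameter is sharded along dim 0, and its size divides evenly across ranks. The function name `scatter_shards` and the even-split assumption are illustrative only:

```python
import torch
import torch.distributed as dist

def scatter_shards(param: torch.Tensor, src_rank: int = 0) -> torch.Tensor:
    """Each rank receives only its own dim-0 shard of `param`."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    shard_size = param.size(0) // world_size  # assumes even divisibility

    # Pre-allocate the receive buffer for this rank's shard.
    local_shard = torch.empty(
        shard_size, *param.size()[1:], dtype=param.dtype, device=param.device
    )

    # Only the source rank supplies the list of shards; all other ranks
    # pass scatter_list=None and simply receive their chunk.
    scatter_list = None
    if rank == src_rank:
        scatter_list = [c.contiguous() for c in torch.chunk(param, world_size, dim=0)]

    dist.scatter(local_shard, scatter_list=scatter_list, src=src_rank)
    return local_shard
```

With `dist.broadcast`, every rank would receive the full tensor and then slice out its shard locally, so the source transmits roughly `world_size` times more data than with `dist.scatter`, which sends each rank only its own chunk.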
ghstack-source-id: 151643718
Test Plan:
test_shard_parameter
test_shard_parameter_errors
Reviewed By: pritamdamania87
Differential Revision: D33933419
fbshipit-source-id: c823c5d0066a9fe7451c07cbacb30a3bbd361af4
(cherry picked from commit b1b553e89296e392a69850e93c950734c9d93c96)