Implement scatter primitive for ProcessGroupNCCL (#70029)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029
This PR implements NCCL scatter and add scatter to ProcessGroupNCCL.
NCCL doesn’t directly provide primitives for scatter, so we need to be implemented on top of NCCL’s send/recv API.
1. In ProcessGroupNCCL.cpp, the inputTensors are first flattened, then outputTensors and inputFlattened are passed by the collective class to scatter() function in nccl.cpp.
2. In nccl.cpp, scatter is implemented using ncclSend/ncclRecv: the root rank uses a for loop to send(distribute) the inputTensors to each rank, then all the ranks receive the inputTensor from the root rank.
ghstack-source-id: 147754837
Test Plan:
test_scatter_ops
test_scatter_stress
test_scatter_checks
Reviewed By: pritamdamania87
Differential Revision: D33154823
fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd
(cherry picked from commit 93201f9d4a87c556110e60ceb93826abd71cf518)