pytorch
6feba4bc - Implement scatter primitive for ProcessGroupNCCL (#70029)

Commit
2 years ago
Implement scatter primitive for ProcessGroupNCCL (#70029) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70029 This PR implements NCCL scatter and add scatter to ProcessGroupNCCL. NCCL doesn’t directly provide primitives for scatter, so we need to be implemented on top of NCCL’s send/recv API. 1. In ProcessGroupNCCL.cpp, the inputTensors are first flattened, then outputTensors and inputFlattened are passed by the collective class to scatter() function in nccl.cpp. 2. In nccl.cpp, scatter is implemented using ncclSend/ncclRecv: the root rank uses a for loop to send(distribute) the inputTensors to each rank, then all the ranks receive the inputTensor from the root rank. ghstack-source-id: 147754837 Test Plan: test_scatter_ops test_scatter_stress test_scatter_checks Reviewed By: pritamdamania87 Differential Revision: D33154823 fbshipit-source-id: 4513e7eaf7d47a60eb67da99dc6c2e9a2882f3fd (cherry picked from commit 93201f9d4a87c556110e60ceb93826abd71cf518)
Author
Committer
Parents
Loading