pytorch
22e8a61d - Implement coalesced reduce_scatter_tensor (#103561)

Commit
1 year ago
Implement coalesced reduce_scatter_tensor (#103561) Map of #101157. This PR adds support for coalesced `reduce_scatter_tensor` calls in the following syntax: Sync communication style: ``` with dist._coalescing_manager(): for i in range(num_coll): dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) ``` Async communication style: ``` with dist._coalescing_manager(async_ops=True) as cm: for i in range(num_coll): dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i]) # do a bunch of other things cm.wait() # do things that depend on the reduce-scatters' results ``` Each `reduce_scatter_tensor` call can be independent in terms of their data and buffer locations. But could be executed in parallel by supported backends (like NCCL). Pull Request resolved: https://github.com/pytorch/pytorch/pull/103561 Approved by: https://github.com/fegin
Author
Committer
Parents
Loading