pytorch
7ebd816a - Switch DTensor to use funcol::all_reduce. (#95804)

This is relanding the troubling part of #95009 that caused a regression.

BC: This changes the signature and semantics of DeviceMesh::all_reduce.

DeviceMesh::all_reduce now uses a functional collective under the hood, which makes it more easily traceable. You no longer need to use CommTensor to get a trace.

all_reduce is now async-only and uses AsyncCollectiveTensor to ensure proper stream synchronization.

Signature change: the async_op param was removed, and the return type changed from Optional[Work] to torch.Tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95804
Approved by: https://github.com/fegin
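A minimal caller-side sketch of the migration the commit message describes. The mesh/process-group setup, tensor shapes, and the exact old parameter list are assumptions for illustration; only the removal of async_op and the Optional[Work] -> torch.Tensor return-type change are taken from the message above.

```python
# Hedged sketch of the DeviceMesh::all_reduce signature change.
# Run under torchrun with a world size > 1; gloo/CPU is used here
# only to keep the example self-contained.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh

dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
t = torch.ones(8)

# Before this commit (per the message): in-place reduce returning Optional[Work].
# work = mesh.all_reduce(t, async_op=True)
# if work is not None:
#     work.wait()

# After this commit: async-only functional collective returning a torch.Tensor
# (an AsyncCollectiveTensor) that synchronizes with the collective when used.
reduced = mesh.all_reduce(t)
print(reduced.sum())  # synchronization happens implicitly on first use
```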
Author: Rodrigo Kumpera