Add ProcessGroupAgent termination detection algorithm (#26984)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26984
closes #26944
In the existing implementation, each worker exits when it sees
no pending send/recv tasks. However, as we are adding support for nested
calls, one RPC could trigger more RPCs in the UDF or in the response
callback. As a result, even if a worker does not currently see any
send/recv tasks, that does not mean none will arrive in the future.
In this commit, we add counters for all sent and received
messages between each pair of nodes, and then use allgather to collect
those counters, i.e., all workers have the same view of the
global state. A worker only exits when all sent messages have been
received and processed.
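
The check described above can be sketched roughly as follows. This is an
illustrative Python simulation only, not the ProcessGroupAgent code from this
commit: the names `all_gather_counters` and `can_terminate`, and the
matrix layout (`sent[src][dst]`, `received[dst][src]`), are assumptions made
for the example, and the real allgather collective is replaced by a plain
in-process copy.

```python
from typing import List

Matrix = List[List[int]]


def all_gather_counters(local_rows: Matrix) -> Matrix:
    # Hypothetical stand-in for an allgather collective: every worker
    # contributes its own row of counters and gets back the full matrix.
    return [list(row) for row in local_rows]


def can_terminate(sent: Matrix, received: Matrix) -> bool:
    # sent[src][dst]: messages worker `src` has sent to worker `dst`.
    # received[dst][src]: messages from `src` that `dst` has received
    # and finished processing.
    # Exiting is safe only if every pair of counters matches.
    n = len(sent)
    return all(
        sent[src][dst] == received[dst][src]
        for src in range(n)
        for dst in range(n)
    )


# Two workers: worker 0 sent 2 messages to worker 1, worker 1 sent 1 back.
sent = all_gather_counters([[0, 2], [1, 0]])

# Worker 1 has processed only 1 of the 2 messages so far: keep running.
in_flight = all_gather_counters([[0, 1], [1, 0]])
print(can_terminate(sent, in_flight))

# All messages processed: every worker reaches the same decision to exit.
drained = all_gather_counters([[0, 1], [2, 0]])
print(can_terminate(sent, drained))
```

Because every worker evaluates the same gathered counters, they all reach the
same decision. In this sketch a worker that cannot yet terminate would keep
serving requests and repeat the gather and check later.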
Test Plan: Imported from OSS
Differential Revision: D17633456
Pulled By: mrshenli
fbshipit-source-id: 813a155d3b2daf2226612eb17f6c698512e9beca