Add new rpc.barrier API (#53423)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53423
closes #40166
This change exposes a new API, rpc.barrier(), which blocks the main thread of each worker running RPC until every worker in the group has reached the barrier. Optionally, rpc.barrier can take a set of worker_names and synchronize only across those workers.
Example:
```python
import os
import torch.multiprocessing as mp
import torch.distributed.rpc as rpc
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]
def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        rpc.barrier(set(odd_num_workers))
    else:
        print(f"start barrier {i}")
        rpc.barrier(set(even_num_workers))
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```
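Conceptually, rpc.barrier() gives RPC workers the same blocking semantics as a standard synchronization barrier: no participant proceeds until every participant has arrived. A minimal single-process sketch using Python's built-in `threading.Barrier` (standing in for the cross-worker RPC synchronization; this is an illustration, not the RPC implementation) shows that semantics:

```python
import threading

world_size = 4
# threading.Barrier plays the role of rpc.barrier() here: each thread
# blocks in wait() until all world_size parties have arrived.
barrier = threading.Barrier(world_size)
results = []
lock = threading.Lock()

def worker(i):
    # Simulate per-worker setup, then synchronize with the group.
    barrier.wait()
    # Every thread reaches this point only after all threads arrived.
    with lock:
        results.append(i)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(results))  # [0, 1, 2, 3] — all workers passed the barrier
```

The subset form, rpc.barrier(set(worker_names)), corresponds to constructing the barrier over only the named parties, as the odd/even split in the example above does.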
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D27737145
Pulled By: H-Huang
fbshipit-source-id: 369196bc62446f506d1fb6a3fa5bebcb0b09da9f