7ee68363 - Add new rpc.barrier API (#53423)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53423

Closes #40166

This change exposes a new API, `rpc.barrier()`, which blocks the main process of each worker running RPC until all workers in the group have reached the barrier. Optionally, `rpc.barrier` can take a set of worker names and synchronize only across those workers.

Example:

```python
import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5678"
world_size = 4

# Partition the workers into two groups by rank parity.
odd_num_workers = [f"worker{i}" for i in range(world_size) if i % 2]
even_num_workers = [f"worker{i}" for i in range(world_size) if not i % 2]

def worker(i):
    print(i)
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    if i % 2:
        print(f"start barrier {i}")
        # Synchronize only among the odd-ranked workers.
        rpc.barrier(set(odd_num_workers))
    else:
        print(f"start barrier {i}")
        # Synchronize only among the even-ranked workers.
        rpc.barrier(set(even_num_workers))
    rpc.shutdown()
    print(f"shutdown{i}")

if __name__ == '__main__':
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D27737145

Pulled By: H-Huang

fbshipit-source-id: 369196bc62446f506d1fb6a3fa5bebcb0b09da9f
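The example above exercises the optional worker-names form. For the default no-argument form, which per the summary blocks until the whole group reaches the barrier, a minimal sketch might look like the following (the port value is an illustrative assumption, not part of the commit):

```python
import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "5679"  # assumed free port, not from the commit
world_size = 4

def worker(i):
    rpc.init_rpc(f"worker{i}", rank=i, world_size=world_size)
    # With no worker_names argument, the barrier spans the whole group:
    # each worker blocks here until all world_size workers have arrived.
    rpc.barrier()
    rpc.shutdown()

if __name__ == "__main__":
    with mp.Pool(processes=world_size) as pool:
        pool.map(worker, range(world_size))
```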