pytorch
57f72b84 - [DDP] Uneven inputs: option to throw early (#56755)

Commit
3 years ago
[DDP] Uneven inputs: option to throw early (#56755)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56755

Rehash of https://github.com/pytorch/pytorch/pull/47488

Adds a flag to the DDP join() context manager that, when specified, throws a StopIteration across all ranks. To do this, we implement the design in #47250. When running with this flag, we schedule an additional allreduce in the case that a joined rank needs to throw a StopIteration. In the forward pass of non-joined ranks, we match this allreduce, and if at least one rank tells us to throw, we raise a StopIteration.

Tested by modifying existing tests, as well as adding additional tests validating that this works with SyncBatchNorm models and a model with custom collectives in the forward pass.

Currently running perf benchmarks; will post when those are available. We expect a small (~2%) perf reduction when enabling this feature due to the blocking allreduce, so we will only recommend it for models with collective communication.

ghstack-source-id: 127883115

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D27958369

fbshipit-source-id: c26f7d315d95f17bbdc28b4a0561916fcbafb7ca
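
The protocol described in the summary can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions, not the DDP implementation: the function name `simulate_throw_early` and the list-based stand-in for the allreduce are hypothetical, and no `torch.distributed` calls are made. Each simulated rank contributes a flag (1 if it has exhausted its inputs, 0 otherwise), the flags are summed as an allreduce would do, and every rank stops as soon as the sum is nonzero.

```python
def simulate_throw_early(batches_per_rank):
    """Simulate the throw-early protocol for ranks with uneven input counts.

    batches_per_rank[r] is the number of batches available to rank r.
    Returns how many iterations each rank completed before all ranks stopped.
    Hypothetical helper for illustration; not part of PyTorch.
    """
    world_size = len(batches_per_rank)
    completed = [0] * world_size
    while True:
        # A rank that has run out of inputs ("joined") contributes 1, others 0.
        flags = [
            1 if completed[r] >= batches_per_rank[r] else 0
            for r in range(world_size)
        ]
        # Stand-in for the extra sum-allreduce: every rank sees the same total.
        should_throw = sum(flags) > 0
        if should_throw:
            # In the real feature each rank would raise here; we just stop.
            return completed
        # No rank asked to throw, so every rank runs one more iteration.
        for r in range(world_size):
            completed[r] += 1


if __name__ == "__main__":
    print(simulate_throw_early([3, 5, 4]))
```

With inputs of 3, 5, and 4 batches, all ranks stop after the shortest rank's 3 iterations, so no rank is left blocked in a collective that its peers will never join. The per-iteration flag exchange is also where the blocking allreduce cost mentioned in the summary comes from.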