[DDP] Uneven inputs: option to throw early (#56755)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56755
Rehash of https://github.com/pytorch/pytorch/pull/47488
Adds a flag to ddp join() context manager that enables throwing a
StopIteration across all ranks when this flag is specified.
To do this, we implement the design in #47250. When running with this flag, we schedule an additional allreduce in the case that a joined rank needs to throw a StopIteration. In non-joined ranks forward pass, we match this allreduce and if at least one rank tells us to throw, we raise a StopIteration.
Tested by modifying existing tests, as well as adding additional tests validating that this works with SyncBatchNorm models and a model with custom collectives in the forward pass.
Currently running perf benchmarks, will post when those are available, but we expect a small (~2%) perf reduction when enabling this feature due to the blocking allreduce. Hence we will only recommend it for models with collective comm.
ghstack-source-id: 127883115
Test Plan: Ci
Reviewed By: SciPioneer
Differential Revision: D27958369
fbshipit-source-id: c26f7d315d95f17bbdc28b4a0561916fcbafb7ca