pytorch
57feb354 - Refactor non-joined process computation (#61555)

Refactor non-joined process computation (#61555)

Summary:

**Overview:**
This refactors the computation performed on non-joined processes relating to the join context manager. The concept was inspired by a comment from pritamdamania.

**Changes:**
This introduces a `_Joinable` abstract base class, which requires a `_join_hook()` method and `_join_device()` and `_join_process_group()` property methods. Any class that we want to be compatible with the generic join context manager should inherit from `_Joinable` and implement `_join_hook()`, `_join_device()`, and `_join_process_group()`. (The `device` and `process_group` information has been moved from `_JoinHook` to `_Joinable`.)

The generic join context manager now takes in a `List[_Joinable]` instead of a `List[_JoinHook]`. The motivation is that previously, by passing the `_JoinHook`s into the context manager, the class providing a `_JoinHook` could modify the context manager's behavior, but the context manager could not modify the class's behavior. This is solved by giving the context manager a reference to the class's instance.

This implementation reserves the field `_join_config` in every `_Joinable` to store a `_JoinConfig` instance, which holds all dynamic fields needed from the `_Joinable` by the join context manager: `enable`, `throw_on_early_termination`, and `is_first_joinable`. ("Dynamic" here means that for a given `_Joinable` instance, the values of those fields may change across different join context usages.)

In particular, these fields are needed to implement the method `notify_join_context()`, which encapsulates the computation performed on non-joined processes relating to the join context manager: (1) the all-reduce indicating that the process has not yet joined and (2) the all-reduce checking whether to throw an exception if `throw_on_early_termination=True`. The idea is that every `_Joinable` class only needs to call `notify_join_context()` before its per-iteration collective communications; it is a simple one-line addition. Only the first `_Joinable` instance passed into the context manager actually performs the collective communications in `notify_join_context()`. In that case, the method returns an async work handle for the initial all-reduce indicating that the process has not yet joined; otherwise, the method returns `None`. This conditional logic is handled internally without additional input from the user.

**New API:**
Now, the example usage looks like:
```
ddp_model = DistributedDataParallel(...)
zero_optim = ZeroRedundancyOptimizer(ddp_model.parameters(), ...)
with _Join([ddp_model, zero_optim]):
    ...
```
Any arguments meant for a join hook (e.g. `divide_by_initial_world_size`) must be specified as keyword arguments. For example:
```
with _Join([ddp_model, zero_optim], divide_by_initial_world_size=False):
    ...
```
They are forwarded to every `_join_hook()` function via `**kwargs`. This creates a clear separation between the variables needed by the context manager (`enable` and `throw_on_early_termination`) and those needed by the `_Joinable` class (e.g. `divide_by_initial_world_size`).

**Recap:**
After this change, using the generic join context manager looks like the following (omitting the prefix `_` from names; see the sketch after this list):
- Suppose we have a class `C` (e.g. `DistributedDataParallel`) that we want to be able to use with the `Join` context.
- We make `C` inherit from `Joinable` and implement `join_hook() -> JoinHook`, `join_device()`, and `join_process_group()`.
- To implement `join_hook()`, we define a `CJoinHook` class inheriting from `JoinHook` and implement `main_hook()` and `post_hook()` as needed.
- We locate a place before `C`'s per-iteration collective communications and add a call to `Join.notify_join_context()`.
- We call `Joinable.__init__(self)` in `C`'s constructor.
- The `C.join_config` field is used internally by the context manager. This does not affect `C`'s serializability.
- Run-time arguments for `C`'s join hook can be passed in as keyword arguments to the context manager: `with Join([C()], arg1=..., arg2=...):`.
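To make the recap concrete, here is a minimal sketch of a toy `Joinable` following those steps. The `Counter` and `CounterJoinHook` classes are hypothetical and exist only for illustration; the underscore-free names follow the recap (the private `_Join`/`_Joinable` names in this PR may differ slightly in module path and signature, while later PyTorch releases expose the public equivalents under `torch.distributed.algorithms`).
```
# Illustrative sketch only: `Counter` and `CounterJoinHook` are hypothetical
# classes following the recap; they are not part of PyTorch itself.
import torch
import torch.distributed as dist
from torch.distributed.algorithms import Join, Joinable, JoinHook

class CounterJoinHook(JoinHook):
    """Shadows the per-iteration all-reduce on ranks that have joined."""
    def __init__(self, counter):
        self.counter = counter

    def main_hook(self):
        # Runs while other ranks are still iterating: contribute zero so the
        # collective still matches across the process group.
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t, group=self.counter.process_group)

    def post_hook(self, is_last_joiner: bool):
        # Runs once after all ranks have joined; nothing to do in this toy.
        pass

class Counter(Joinable):
    """Toy Joinable that all-reduces a counter once per iteration."""
    def __init__(self, device, process_group):
        Joinable.__init__(self)   # reserves the internal join config field
        self.device = device
        self.process_group = process_group
        self.total = torch.zeros(1, device=device)

    def step(self):
        # The one-line addition before the per-iteration collective:
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device)
        dist.all_reduce(t, group=self.process_group)
        self.total += t

    def join_hook(self, **kwargs) -> JoinHook:
        return CounterJoinHook(self)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

# Per-rank usage (after dist.init_process_group(...)); iteration counts may
# differ across ranks, which is what the join context handles:
# counter = Counter(torch.device("cuda", dist.get_rank()), dist.group.WORLD)
# with Join([counter]):
#     for _ in range(num_local_iters):
#         counter.step()
```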
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61555

Test Plan:
I ran the existing DDP join tests:
```
touch /tmp/barrier && TEMP_DIR="/tmp" BACKEND="nccl" WORLD_SIZE="2" gpurun python test/distributed/test_distributed_fork.py -- \
    TestDistBackendWithFork.test_ddp_uneven_inputs \
    TestDistBackendWithFork.test_ddp_uneven_inputs_stop_iteration_sync_bn \
    TestDistBackendWithFork.test_ddp_grad_div_uneven_inputs \
    TestDistBackendWithFork.test_ddp_uneven_input_join_disable \
    TestDistBackendWithFork.test_ddp_uneven_input_exception
```
I ran the ZeRO join tests:
```
gpurun4 python test/distributed/optim/test_zero_redundancy_optimizer.py \
    TestZeroRedundancyOptimizerDistributed.test_zero_join_gpu \
    TestZeroRedundancyOptimizerDistributed.test_zero_join_cpu
```

Reviewed By: zou3519

Differential Revision: D29690359

Pulled By: andwgu

fbshipit-source-id: 2950f78de755eb5fb13b95b803dd7c705879a9c7