Support accumulating DDP grads using a context manager (#21736)
Summary:
The first attempt and more discussions are available in https://github.com/pytorch/pytorch/issues/19577
#### Goal
Allow toggling DDP gradient synchronization across iterations. With this feature, users may accumulate grads in module variables and only kick off the expensive grad synchronization every few iterations.
#### Concerns
Our first attempt in https://github.com/pytorch/pytorch/issues/19577 tried to do this using a variable or a function. But apaszke made a good point that this would be error prone, and favored a context manager instead.
#### Proposed Solution
Instead of providing an `accumulate_grads` variable/function/context, we provide a `DistributedDataParallel.no_sync()` context manager. It does exactly what the name suggests, i.e., it disables DDP grad synchronization within the context. Note that `accumulate_grads` means `no_sync` + no optimizer step, where the latter is not controlled by DDP.
It is true that users need to call another `model(input).backward()` after exiting the context, and this is indeed more verbose. But I think it is OK, as one major concern in the previous discussion was to prevent users from running into errors without knowing it. This API should reaffirm the expected behavior, and it does not interfere with other use cases when accumulating grads is not required.
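To illustrate why a context manager is less error prone than a plain flag, here is a minimal sketch in pure Python. `ToyDDP`, its `require_sync` attribute, and the `sync_count` counter are hypothetical names invented for this illustration; this is not the DDP implementation, and the "synchronization" is only a counter.

```python
from contextlib import contextmanager


class ToyDDP:
    """Hypothetical stand-in for a DDP wrapper; not the real implementation."""

    def __init__(self):
        self.require_sync = True  # assumed flag name, for illustration only
        self.local_grad = 0.0     # gradient accumulated locally
        self.sync_count = 0       # number of pretend synchronizations

    @contextmanager
    def no_sync(self):
        # Save the previous setting so it is restored even if the body raises.
        previous = self.require_sync
        self.require_sync = False
        try:
            yield
        finally:
            self.require_sync = previous

    def backward(self, grad):
        # Always accumulate locally; only "synchronize" when allowed.
        self.local_grad += grad
        if self.require_sync:
            self.sync_count += 1


ddp = ToyDDP()
with ddp.no_sync():
    for g in (1.0, 2.0, 3.0):
        ddp.backward(g)  # accumulates without synchronizing
ddp.backward(4.0)        # outside the context: one synchronization
```

Because the flag is restored in a `finally` clause, a user cannot forget to turn synchronization back on, which is the failure mode a bare variable or function would invite.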
The application would then look like:
```python
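# Accumulate grads locally; no synchronization inside the context.
with ddp.no_sync():
    for input in inputs:
        ddp(input).backward()
# One more forward/backward outside the context triggers synchronization.
ddp(one_more_input).backward()
optimizer.step()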
```
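For the "synchronize every few iterations" use case from the Goal section, one possible pattern is a small helper that picks the context per step. This is a sketch under stated assumptions: `sync_context` and `accumulation_steps` are hypothetical names that are not part of this PR, and the only thing assumed about `model` is that it exposes the `no_sync()` context manager proposed here.

```python
from contextlib import nullcontext


def sync_context(model, step, accumulation_steps):
    """Return a context that skips grad sync except on every N-th step.

    Hypothetical helper: `model` only needs a `no_sync()` context manager.
    `step` is the zero-based iteration index.
    """
    is_sync_step = (step + 1) % accumulation_steps == 0
    # nullcontext() does nothing, so the usual synchronization takes place.
    return nullcontext() if is_sync_step else model.no_sync()


# Intended use (names such as `ddp`, `batches`, `optimizer` are placeholders):
#
# for step, batch in enumerate(batches):
#     with sync_context(ddp, step, accumulation_steps=4):
#         ddp(batch).backward()
#     if (step + 1) % 4 == 0:
#         optimizer.step()
#         optimizer.zero_grad()
```

The optimizer step stays in user code, consistent with the note above that it is not controlled by DDP.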
chenyangyu1988 myleott
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21736
Differential Revision: D15805215
Pulled By: mrshenli
fbshipit-source-id: 73405797d1e39965c52016af5cf45b15525ce21c