DDP should not set grad for globally unused params (#28883)
Summary:
https://github.com/pytorch/pytorch/issues/28294: DDP should not set grad for globally unused parameters
DDP currently computes the param-to-bucket mapping upfront and allreduces grads for all params in every iteration. Even if a param is unused, DDP simply sets its grad to zero. As a result, the optimizer cannot tell whether a param truly has a zero grad or was just not used in the current iteration. This can cause convergence problems for optimizers with weight decay and momentum, such as SGD. However, DDP cannot simply set grad to None for locally unused parameters either, because a locally unused parameter might still be used by other processes, in which case its grad must still participate in the allreduce. Instead, DDP should determine which parameters are globally unused and avoid touching their grads at the end of the backward pass.
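To illustrate the convergence issue, here is a minimal standalone sketch (not part of this PR) showing that SGD with momentum and weight decay still updates a param whose grad is zero, while it skips a param whose grad is None:

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.ones(1))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9, weight_decay=0.01)

# What DDP used to do for an unused param: leave a zero grad behind.
p.grad = torch.zeros(1)
opt.step()
print(p.data)  # moved away from 1.0 by weight decay, despite a zero grad

# With grad left as None, SGD skips the param entirely.
p.grad = None
opt.step()
print(p.data)  # unchanged
```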
Implementation summary:
* Add a locally used parameter map for each model replica.
* Mark the locally unused parameters at the end of the forward pass, then allreduce to get the globally unused parameters (see the sketch after this list).
* At the end of the backward pass, skip touching grads for globally unused parameters.
* Add a unit test, test_global_local_unused_params_grad.
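A simplified Python sketch of the reduction step, with illustrative names (the real implementation lives in DDP's C++ reducer): each process builds a 0/1 map of its locally used params, an allreduce sums the maps across processes, and any param whose global count is zero was used nowhere, so its grad is left untouched.

```python
import torch
import torch.distributed as dist

def compute_globally_unused(locally_used: torch.Tensor) -> torch.Tensor:
    """locally_used[i] == 1 if param i was used in this process's
    forward pass, else 0. Summing across processes reveals which
    params were used by no process at all."""
    global_used = locally_used.clone()
    dist.all_reduce(global_used, op=dist.ReduceOp.SUM)
    return global_used == 0  # True where the param is globally unused

# At the end of backward (sketch): only write reduced grads for params
# that some process actually used; leave the others' .grad as-is.
def finalize_grads(params, reduced_grads, globally_unused):
    for i, param in enumerate(params):
        if globally_unused[i]:
            continue  # grad stays None (or whatever it already was)
        param.grad = reduced_grads[i]
```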
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28883
Differential Revision: D18491530
Pulled By: mrshenli
fbshipit-source-id: 24e9b5f20df86c34ddbf9c7106250fd6ce186699