Fix C++ data parallel (#20910)
Summary:
Fixes #19540
CC nmerrill67
C++ data parallel was using `Module.clone()` to create module replicas on every destination device. However, `clone()` does not set up gradient edges to point from replicas to the original module. As a result, gradients will not be aggregated into the original module. This commit fixes the problem by manually setting gradient edges from every parameter X in every replica to the same parameter X in the original module.
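For context, a minimal usage sketch of the behavior this commit fixes (this is illustration only, not code from the PR, and it assumes at least one CUDA device is available since `data_parallel` defaults to all CUDA devices):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::nn::Linear model(10, 5);

  auto input = torch::randn({8, 10});
  auto output = torch::nn::parallel::data_parallel(model.ptr(), input);
  output.sum().backward();

  // Before this fix, gradients only reached the clone()d replicas and never
  // made it back to the original module. With the fix, the gradient is
  // aggregated into `model`'s own parameters.
  std::cout << model->weight.grad().defined() << std::endl;
}
```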
## Failed Attempt
Initially I tried implementing what we did in `replicate.py`, which:
1. creates module replicas;
2. uses the Python `Broadcast` autograd function to broadcast every parameter in the original module to all destination devices;
3. assigns the broadcast result params to the module replicas' `_parameters` dicts.
This works in Python because a derived module's member field (e.g., `Linear.weight`) and the base module's `_parameters` entry (e.g., `Linear._parameters['weight']`) reference the same parameter instance, so assigning to one applies to both. However, in C++, even though I can modify a `Module`'s `parameters_` values and gradient edges to point to the broadcast source, I cannot touch member fields such as `weight` and `bias` in `Linear`, because `replicate` cannot (and should not) add special-case handlers for every different module. (See `Linear` [.h](https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/modules/linear.h), [.cpp](https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/src/nn/modules/linear.cpp).) Although `Module.parameters_['weight']` and `Linear.weight` initially point to the same `TensorImpl` instance, after assigning to `Module.parameters_['weight']` the two become different tensors.
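To make the handle-rebinding issue concrete, here is a small standalone sketch (not code from this PR; it uses `named_parameters()` rather than the protected `parameters_` member purely to illustrate the Tensor-handle semantics):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::nn::Linear linear(3, 4);

  // The member field `linear->weight` and the "weight" entry in the
  // parameter dict start out as two Tensor handles sharing one TensorImpl.
  auto params = linear->named_parameters();
  std::cout << params["weight"].is_same(linear->weight) << std::endl;  // 1

  // Rebinding the dict entry only re-points that handle at a new TensorImpl;
  // the member field still refers to the old one. This is why the
  // replicate.py-style approach cannot update `Linear.weight` from outside.
  params["weight"] = torch::zeros({4, 3});
  std::cout << params["weight"].is_same(linear->weight) << std::endl;  // 0
}
```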
## Solution Options
gchanan and I had several discussions on this issue and came up with two solutions to this problem.
### Option One [implemented in this PR]
Replicate the module in two steps:
1. call `Module.clone()` to create a module replica on every destination device.
2. manually set gradient edges from every parameter in every replica to the same parameter in the original module (see the sketch after the pros and cons below).
* Pro: Does not require changing any existing module, and is relatively easy to implement.
* Con: It is a little hackish.
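A user-level sketch of what the two steps buy us (again an illustration rather than the PR's code, and it assumes two CUDA devices; `replicate()` lives in `torch::nn::parallel`):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::nn::Linear linear(3, 4);

  // Step 1 happens inside replicate(): clone() produces one copy per device.
  // Step 2 is this PR's fix: every replica parameter's gradient edge is
  // pointed back at the corresponding parameter of the original `linear`.
  std::vector<torch::Device> devices = {
      torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
  auto replicas = torch::nn::parallel::replicate(linear, devices);

  auto out = replicas[0]->forward(
      torch::randn({2, 3}, torch::Device(torch::kCUDA, 0)));
  out.sum().backward();

  // Because of step 2, the gradient reaches the original module instead of
  // stopping at the replica.
  std::cout << linear->weight.grad().defined() << std::endl;
}
```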
### Option Two
Implement a `Replicatable` class (similar to `Cloneable`), and make it a friend class of `Module`. For more details see `Note [Replicating Modules]` in the code change; a purely hypothetical sketch of the shape this could take follows the pros and cons below.
* Pro: Maybe this aligns more with our existing approach implemented in `Cloneable`?
* Con: Requires changes to every existing module.
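Purely for illustration, one hypothetical shape such a mixin could take (nothing below exists in PyTorch; the class and method names are made up):

```cpp
#include <torch/torch.h>

// Hypothetical CRTP mixin, analogous to torch::nn::Cloneable<Derived>.
// replicate() would call reset_replica() on each replica so the module can
// re-point its own member fields (e.g. Linear::weight) at the broadcast
// parameters -- which is exactly why this option touches every module.
template <typename Derived>
class Replicatable : public torch::nn::Cloneable<Derived> {
 public:
  virtual void reset_replica(
      const torch::OrderedDict<std::string, torch::Tensor>& broadcast_params) = 0;
};
```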
I am inclined to go with option one, because `replicate` will only be used for data parallel. Changing all existing module implementations to satisfy a data parallel requirement feels like overkill.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20910
Differential Revision: D15556426
Pulled By: mrshenli
fbshipit-source-id: aa836290ec657b32742e2bea80bd0ac2404ef3b0