Improve the documentation of DistributedDataParallel (#42471)
Summary:
Fixes #{issue number}
The statement that 'gradients from each node are averaged' in the documentation of DistributedDataParallel is not clearly illustrated. Many people, including me, had a completely wrong understanding of this part. I added a note to the documentation to make it more straightforward and more user friendly.
Here is some toy code to illustrate my point:
* non-DistributedDataParallel version
```python
import torch
import torch.nn as nn

# Single-process baseline: one model sees the full 4-sample batch.
inputs = torch.tensor([-1, 2, -3, 4], dtype=torch.float).view(-1, 1)
print("input:", inputs)

model = nn.Linear(in_features=1, out_features=1, bias=False)
# Start the single weight at exactly 1.0 so the result is reproducible.
model.weight.data.fill_(1.0)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer.zero_grad()

prediction = model(inputs)
target = torch.zeros(4, 1, dtype=torch.float)
# Sum-of-squared-errors loss; with w = 1 the gradient is 2 * sum(x^2) = 60.
squared_error = torch.sum((prediction - target) ** 2)
squared_error.backward()
optimizer.step()

print("grad:", model.weight.grad)
print("updated weight:\n", model.weight)
# OUTPUT
# $ python test.py
# input: tensor([[-1.],
# [ 2.],
# [-3.],
# [ 4.]])
# grad: tensor([[60.]])
# updated weight:
# Parameter containing:
# tensor([[0.9400]], requires_grad=True)
```
* DistributedDataParallel version
```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.multiprocessing import Process


def run(rank, size):
    """Run one SGD step on this rank's shard of the batch under DDP.

    Each rank sees a different 2-sample shard of the original 4-sample
    batch, so the local gradients differ (rank 0: 2*(1+4) = 10,
    rank 1: 2*(9+16) = 50).  DDP *averages* them during backward,
    giving (10 + 50) / 2 = 30 on every rank, hence the updated weight
    1 - 0.001 * 30 = 0.97 (not 0.94 as in the single-process run).
    """
    x = torch.tensor([-(1 + 2 * rank), 2 + 2 * rank], dtype=torch.float).view(-1, 1)
    print("input:", x)
    model = nn.Linear(in_features=1, out_features=1, bias=False)
    # Start the single weight at exactly 1.0 so the result is reproducible.
    model.weight.data.zero_()
    model.weight.data.add_(1.0)
    model = torch.nn.parallel.DistributedDataParallel(model)
    opti = torch.optim.SGD(model.parameters(), lr=0.001)
    opti.zero_grad()
    y = model(x)
    label = torch.zeros(2, 1, dtype=torch.float)
    loss = torch.sum((y.view(-1, 1) - label) ** 2)
    loss.backward()  # DDP all-reduces and averages gradients here
    opti.step()
    if rank == 0:
        print("grad:", model.module.weight.grad)
        print("updated weight:\n", model.module.weight)


def init_process(rank, size, fn, backend="gloo"):
    """Initialize the default process group for this rank, then call fn."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
# OUTPUT (the two "input:" lines may interleave, as the ranks print concurrently)
# $ python test_d.py
# input: tensor([[-3.],
# [ 4.]])input: tensor([[-1.],
# [ 2.]])
# grad: tensor([[30.]])
# updated weight:
# Parameter containing:
# tensor([[0.9700]], requires_grad=True)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42471
Reviewed By: glaringlee
Differential Revision: D22923340
Pulled By: mrshenli
fbshipit-source-id: 40b8c8ba63a243f857cd5976badbf7377253ba82