pytorch
9a3e16c7 - Add guard for non-default stream in DDP's autograd engine callback (#40115)

Commit
4 years ago
Add guard for non-default stream in DDP's autograd engine callback (#40115) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40115 Closes https://github.com/pytorch/pytorch/issues/37790 Closes https://github.com/pytorch/pytorch/issues/37944 A user may wish to run DDP's forward + backwards step under a non-default CUDA stream such as those created by `with torch.cuda.Stream(stream)`. In this case, the user should be responsible for synchronizing events on this stream with other streams used in the program (per the documentation at https://pytorch.org/docs/stable/notes/cuda.html#cuda-semantics), but currently DDP has a bug which causes DDP under non-default streams to fail. If a user does the following: ``` model = DDP(...) loss = model(inptut).sum() loss.backward() grad = model.module.weight.grad() average = dist.all_reduce(grad) ``` There is a chance that `average` and `grad` will not be equal. This is because the CUDA kernels corresponding to the `all_reduce` call may run before `loss.backward()`'s kernels are finished. Specifically, in DDP we copy the allreduced gradients back to the model parameter gradients in an autograd engine callback, but this callback runs on the default stream. Note that this can also be fixed by the application synchronizing on the current stream, although this should not be expected, since the application is not using the current stream at all. This PR fixes the issue by passing the current stream into DDP's callback. Tested by adding a UT `test_DistributedDataParallel_non_default_stream` that fails without this PR ghstack-source-id: 106481208 Differential Revision: D22073353 fbshipit-source-id: 70da9b44e5f546ff8b6d8c42022ecc846dff033e
Author
Parents
Loading