pytorch/pytorch — commit 0f6876d7: [Model Averaging] Create a post-localSGD communication hook (#61206)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61206

Create a communication hook to run post-local SGD. This will be combined with the model averager component to better support local SGD.

In contrast to the previous approach, which ran local gradient averaging plus global model averaging at each of the first K steps, we now plan to run only global gradient averaging at each of the first K steps, just like normal DDP. This gives us two advantages:
1) For some optimizers, model averaging can cause a discrepancy in optimizer states. If we still do global gradient averaging for the first K steps, we can defer such discrepancy until we actually start local SGD.
2) Gradient averaging during the first K steps runs only one allreduce, which overlaps with the backward pass, so it should also be more efficient.

Proposal: https://github.com/pytorch/pytorch/issues/59699

ghstack-source-id: 133371322

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_hook_parity_post_localSGD

Reviewed By: pritamdamania87

Differential Revision: D29523292

fbshipit-source-id: 3f215f7150f2917c2781278fad759530c685ea2c
Author: Yi Wang
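
For context, here is a minimal sketch of how the hook introduced by this commit is wired up together with a periodic model averager. It is based on the public post-localSGD API that this work fed into (`PostLocalSGDState`, `post_localSGD_hook`, and `PeriodicModelAverager` from later follow-up work); argument names, defaults, and module paths may differ across PyTorch versions, so treat this as an illustration rather than the exact state of the tree at this commit.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank: int, world_size: int) -> None:
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

    # For the first `start_localSGD_iter` steps the hook behaves like normal
    # DDP: one global gradient allreduce that overlaps with the backward pass.
    # After that it switches to subgroup-local gradient averaging (local SGD).
    state = PostLocalSGDState(
        process_group=None,      # None: use the default (global) process group
        subgroup=None,           # None: use the intra-machine subgroup
        start_localSGD_iter=100, # the "first K steps" described above
    )
    model.register_comm_hook(state, post_localSGD_hook)

    # The model averager component: once local SGD has started, periodically
    # average parameters globally. `warmup_steps` should match
    # `start_localSGD_iter` so no averaging happens during the DDP phase.
    averager = PeriodicModelAverager(period=4, warmup_steps=100)

    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    for _ in range(200):
        opt.zero_grad()
        loss = model(torch.randn(20, 10, device=rank)).sum()
        loss.backward()
        opt.step()
        # No-op during warmup; afterwards averages every `period` steps.
        averager.average_parameters(model.parameters())
```

Because optimizer steps are taken on locally averaged gradients during the warmup phase, optimizer states stay in sync across ranks until local SGD actually begins, which is the deferred-discrepancy advantage the summary describes.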