[Model Averaging] Create a post-localSGD communication hook (#61206)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61206
Create a communication hook to run post-local SGD. This will be combined with model averager component to better support local SGD.
In contrast to the previous approach that runs local gradient averaging + global model averaging at each step for the first K steps, now we plan to run global gradient averaging only for the first K steps at each step, just like normal DDP. This can give us two advantages:
1) For some optimizers, model averaging can cause discrepancy in optimizer states. If we still do global gradient averaging for the first K steps, we can defer such discrepancy until we actually start local SGD.
2) Gradient averaging at the first K steps only run one allreduce that overlaps with backward pass, so it should also be more efficient.
Proposal: https://github.com/pytorch/pytorch/issues/59699
ghstack-source-id: 133371322
Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_hook_parity_post_localSGD
Reviewed By: pritamdamania87
Differential Revision: D29523292
fbshipit-source-id: 3f215f7150f2917c2781278fad759530c685ea2c