Make grad point to bucket buffer in DDP to save memory usage (#41954)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41954
Make both variable.grad() and the grad in the distributed autograd context point to the bucket buffers in DDP to reduce memory usage.
In this case, grads will be views of the bucket buffer tensors. To make this compatible with optimizer.zero_grad(), we
made changes in https://github.com/pytorch/pytorch/pull/41283.
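A minimal Python sketch of the idea (not the actual DDP internals; the bucket layout here is illustrative): each parameter's grad becomes a view into a slice of one flat bucket buffer, so zeroing a grad in place zeroes the corresponding bucket slice, which is why zero_grad() must not detach or reallocate grads.

```python
import torch

# Illustrative flat bucket holding the grads of two parameters.
params = [torch.randn(3, requires_grad=True), torch.randn(5, requires_grad=True)]
bucket = torch.randn(sum(p.numel() for p in params))

offset = 0
for p in params:
    # p.grad is a view into the bucket; writes to one are visible in the other.
    p.grad = bucket[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

# In-place zeroing of the grad view zeroes the bucket slice as well.
params[0].grad.zero_()
assert bucket[:3].eq(0).all()
```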
Also note that we cannot make variable.grad() point to the bucket buffer at construction time, because we want to
keep grad undefined for unused parameters.
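A hedged illustration of why this matters: if every grad were hooked to the bucket at construction, all parameters would have a defined (allocated) grad, and unused parameters could no longer be detected via `p.grad is None`.

```python
import torch

used = torch.randn(2, requires_grad=True)
unused = torch.randn(2, requires_grad=True)

loss = used.sum()  # `unused` never participates in the graph
loss.backward()

assert used.grad is not None
assert unused.grad is None  # stays undefined for the unused parameter
```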
ghstack-source-id: 110260297
Test Plan:
Unit tests.
For the roberta_base model with ~1GB of parameters, peak memory dropped by ~1GB (8250MB -> 7183MB) and per-iteration latency improved from 0.982s to 0.909s, an 8% speedup:
https://www.internalfb.com/intern/fblearner/details/211713882?tab=operator_details
https://www.internalfb.com/intern/fblearner/details/211772923?tab=operator_details
For a resnet model with ~97M parameters, peak memory dropped by ~100MB (3089MB -> 2988MB); per-iteration latency is unchanged (0.122s -> 0.123s):
https://www.internalfb.com/intern/fblearner/details/211713577?tab=operator_details
https://www.internalfb.com/intern/fblearner/details/211712582?tab=operator_details
The accuracy benchmark is as expected as well:
https://www.internalfb.com/intern/fblearner/details/213237067?tab=Outputs
Reviewed By: mrshenli
Differential Revision: D22707857
fbshipit-source-id: b5e767cfb34ccb3d067db2735482a86d59aea7a4