[pytorch][PR][Gradient Compression] Reduce the peak memory of fp16 compression provided by ddp comm hook (#46078)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46078
The peak memory usage of the ddp comm hook has increased due to an extra copy of the gradient tensors. To reduce memory usage, decompress the fp16 tensor in place, writing the result directly back into the tensor stored in the gradient bucket instead of allocating a new tensor.
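For reference, a minimal sketch of what an in-place fp16 compression hook can look like with the public comm-hook API (`GradBucket.buffer()`, `Future.then()`); exact names may differ from the code in this diff:

```python
import torch
import torch.distributed as dist


def fp16_compress_hook(process_group, bucket):
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # Compress the bucket's gradients to fp16 and pre-divide by world size.
    compressed_tensor = bucket.buffer().to(torch.float16).div_(world_size)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()

    def decompress(fut):
        # Copy the allreduced fp16 result back into the bucket's own tensor
        # instead of allocating a new decompressed tensor; this avoids the
        # extra copy that raised peak memory.
        decompressed_tensor = bucket.buffer()
        decompressed_tensor.copy_(fut.value()[0])
        return decompressed_tensor

    return fut.then(decompress)


# Hypothetical usage: register the hook on a DistributedDataParallel model.
# model.register_comm_hook(state=None, hook=fp16_compress_hook)
```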
Closes: https://github.com/pytorch/pytorch/issues/45968
ghstack-source-id: 113996453
Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_accumulate_gradients_no_sync_allreduce_hook
Also verified the decrease in memory consumption with some toy modeling examples.
Reviewed By: pritamdamania87
Differential Revision: D24178118
fbshipit-source-id: 453d0b52930809bd836172936b77abd69610237a