pytorch
9d731e87 - [Gradient Compression] Explicitly specify the dtype of the error tensor (#50985)

Commit
3 years ago
[Gradient Compression] Explicitly specify the dtype of the error tensor (#50985) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50985 Explicitly specify the dtype of error tensor when it is initialized by zeros. Previously if the dtype of input tensor is FP16, the error tensor is still created in FP32, although later it will be assigned by another FP16 tensor (`input_tensor_cp` - `input_tensor`). This change will make the dtype of error tensor look more clear. Additionally, also explicitly specify the dtype if rank-1 tensor buffer is empty. Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202 ghstack-source-id: 120377786 Test Plan: buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook Reviewed By: rohan-varma Differential Revision: D26034988 fbshipit-source-id: e0d323d0b77c6a2478cdbe8b31a1946ffd1a07da
Author
Yi Wang
Parents
Loading