[Reland2][DDP] Merge work and future_work in reducer (#59574)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59574
Remove `work` attribute from Reducer class in favor of `future_work`.
Additionally, remove the `copy_grad_to_bucket` method, since it is now a one-line implementation, and create a new C++ comm hook called `_AllReduceCommHookWithDivFactor` to replace the built-in allreduce and also support handling uneven inputs.
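For context, the future-based flow that `future_work` standardizes on is the same one user-registered DDP communication hooks already follow at the Python level. Below is a minimal, illustrative sketch (not the C++ `_AllReduceCommHookWithDivFactor` added in this PR) of an allreduce hook that pre-divides the bucket and returns a `torch.futures.Future`; it assumes the public `GradBucket.buffer()` and `register_comm_hook` APIs.

```python
import torch
import torch.distributed as dist

def allreduce_future_hook(process_group, bucket):
    # `process_group` is whatever state was passed to register_comm_hook;
    # fall back to the default (WORLD) group when it is None.
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group=group)
    # Pre-divide the flattened bucket, then launch an async allreduce
    # and return its Future instead of a Work handle.
    tensor = bucket.buffer().div_(world_size)
    fut = dist.all_reduce(tensor, group=group, async_op=True).get_future()
    # The Future's value is a one-element list holding the reduced tensor.
    return fut.then(lambda f: f.value()[0])

# Hypothetical usage on a DistributedDataParallel model:
# ddp_model.register_comm_hook(state=None, hook=allreduce_future_hook)
```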
1) Compared with the reverted https://github.com/pytorch/pytorch/pull/58937, `_AllReduceCommHookWithDivFactor` in `default_comm_hooks.cpp` now applies the division before the allreduce, which avoids FP16 overflow (see the numeric sketch after this list).
2) Compared with the reverted https://github.com/pytorch/pytorch/pull/59520, disabled `test_DistributedDataParallel_non_default_stream` on AMD, because applying the division first now hurts gradient-averaging accuracy on that platform.
See [07:48:26]:
https://ci.pytorch.org/jenkins/job/pytorch-builds/job/pytorch-linux-bionic-rocm4.2-py3.6-test1/1129/console
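As a small numeric illustration of item 1 (not code from this PR): FP16 has a maximum finite value of about 65504, so summing per-rank gradients before dividing can overflow, while dividing by the world size first keeps every intermediate value in range.

```python
import torch

world_size = 8
per_rank_grad = torch.full((4,), 16384.0, dtype=torch.float16)

# Sum-then-divide: the running sum exceeds the FP16 max (~65504) and becomes inf.
summed = sum(per_rank_grad for _ in range(world_size)) / world_size
print(summed)    # tensor([inf, inf, inf, inf], dtype=torch.float16)

# Divide-then-sum: each addend is 2048, so the average (16384) is computed exactly.
averaged = sum(per_rank_grad / world_size for _ in range(world_size))
print(averaged)  # tensor([16384., 16384., 16384., 16384.], dtype=torch.float16)
```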
#Original PR Issue: https://github.com/pytorch/pytorch/issues/41266
ghstack-source-id: 130752393
Test Plan:
buck test caffe2/test/distributed:distributed_gloo_fork -- test_accumulate_gradients_no_sync
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_accumulate_gradients_no_sync
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_ddp_grad_div_uneven_inputs
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_fp16_grad_is_view
buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_non_default_stream
Reviewed By: rohan-varma
Differential Revision: D28940800
fbshipit-source-id: 1ba727ac951ebc1e7875dc1a1be8108a2c8d9462