Added communication hook for sharded cases (#83254)
Fixes https://github.com/pytorch/pytorch/issues/79114
An implementation of an FSDP communication hook interface for sharded strategies:
- Added `reduce_scatter_hook` to default hooks. Note that, unlike `all_reduce`, `reduce_scatter` requires two tensors: `input_gradient` and `output`. It stores the result in `output`, which is then used as the summed gradient shard.
- Adjusted FSDP logic to return `reduce_scatter_hook` as the default communication hook for sharded strategies; `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted low-precision hooks to work with both `all_reduce` and `reduce_scatter`, depending on whether an `output` tensor is provided (see the sketch after this list).
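For illustration only, here is a minimal sketch of the hook shapes described above, using public `torch.distributed` collectives. This is not the code added by the PR; the function names and the `state` attribute names (`process_group`, `gradient_predivide_factor`, `gradient_postdivide_factor`) are assumptions made for the sketch.

```python
import torch
import torch.distributed as dist


def allreduce_hook_sketch(state, grad: torch.Tensor):
    # Non-sharded strategies: every rank ends up with the full reduced gradient.
    grad.div_(state.gradient_predivide_factor)
    dist.all_reduce(grad, group=state.process_group)
    grad.div_(state.gradient_postdivide_factor)


def reduce_scatter_hook_sketch(state, grad: torch.Tensor, output: torch.Tensor):
    # Sharded strategies: `grad` is the full flattened gradient, `output` is this
    # rank's shard; the reduced result is written into `output`.
    grad.div_(state.gradient_predivide_factor)
    dist.reduce_scatter_tensor(output, grad, group=state.process_group)
    output.div_(state.gradient_postdivide_factor)


def low_precision_hook_sketch(prec: torch.dtype, state, grad: torch.Tensor, output=None):
    # Low-precision hooks cast the gradient down, then dispatch on whether an
    # `output` shard is provided: all_reduce (non-sharded) vs. reduce_scatter (sharded).
    grad.data = grad.data.to(prec)
    if output is not None:
        output.data = output.data.to(prec)
        reduce_scatter_hook_sketch(state, grad, output)
        output.data = output.data.to(torch.float32)  # cast the shard back up (fp32 assumed here)
    else:
        allreduce_hook_sketch(state, grad)
        grad.data = grad.data.to(torch.float32)
```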
Test plan:
Added all existing sharded strategies as input parameters to the existing tests.
For `test_default_communication_hook_behaviour`, double-checked how a linear layer is sharded across workers. This test creates a simple ``1 x N`` net, where ``N`` is the number of workers. For sharded cases, the ``N`` parameters are sharded across ``N`` workers. The test checks that after the backward pass each worker holds the proper value in its chunk of the gradient, or, for non-sharded cases, that the whole gradient on every worker equals the expected value.
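A hedged sketch of that kind of check, not the actual test code: the function name, device handling, and assertions are assumptions, and it presumes a process group and one GPU per rank are already set up.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


def check_sharded_default_hook(rank: int, world_size: int):
    # 1 x N linear layer with N == world_size, so the flattened weight has
    # exactly one element per rank once it is sharded.
    torch.cuda.set_device(rank)
    net = torch.nn.Linear(world_size, 1, bias=False).cuda(rank)
    fsdp_net = FSDP(net, sharding_strategy=ShardingStrategy.FULL_SHARD)

    inp = torch.ones(1, world_size, device=rank)
    fsdp_net(inp).sum().backward()

    # For sharded strategies each rank keeps only its reduce-scattered gradient
    # shard after backward; with this layout that shard is a single element.
    # The expected value itself depends on the loss and the hook's divide factors.
    for param in fsdp_net.parameters():
        assert param.grad is not None
        assert param.grad.numel() == 1  # this rank's chunk of the N gradient entries
```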
Checked that low-precision tests work for sharded cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu