pytorch
f204afc2 - Added communication hook for sharded cases (#83254)

Olga Andreeva committed 2 years ago
Fixes https://github.com/pytorch/pytorch/issues/79114

An implementation of the FSDP communication hook interface for sharded strategies:

- Added `reduce_scatter_hook` to the default hooks. Unlike `all_reduce`, `reduce_scatter` requires two tensors, an `input_gradient` and an `output`, and stores its result in `output`, which is then used as the summed gradient shard.
- Adjusted the FSDP logic to return `reduce_scatter_hook` as the default communication hook for sharded strategies; `DefaultState` is the same for sharded and non-sharded strategies.
- Adjusted the low-precision hooks to work with both `all_reduce` and `reduce_scatter`, depending on whether an `output` tensor is provided.

Test plan:
Added all existing sharded strategies as input parameters to the existing tests. For `test_default_communication_hook_behaviour`, double-checked how a linear layer is sharded across workers. The test creates a simple ``1 x N`` net, where ``N`` is the number of workers. In the sharded cases, the ``N`` parameters are sharded across the ``N`` workers. The test checks that, after the backward pass, each worker holds the expected value in its chunk of the gradient, or that the whole gradient on every worker equals the expected value. Also checked that the low-precision tests work for the sharded cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83254
Approved by: https://github.com/rohan-varma, https://github.com/awgu
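To make the interface difference concrete, below is a minimal sketch of what the two default hooks and a low-precision variant could look like. It assumes a `DefaultState` that only carries the process group and uses the public `torch.distributed` collectives (`all_reduce`, `reduce_scatter_tensor`); the hook names follow the commit description, but this is an illustrative sketch, not the actual FSDP implementation.

```python
# Illustrative sketch only: the real default hooks live in PyTorch's FSDP
# comm-hook module and may differ in names and details. Assumes a DefaultState
# that just wraps the process group.
from typing import Optional

import torch
import torch.distributed as dist


class DefaultState:
    """Shared by sharded and non-sharded strategies."""

    def __init__(self, process_group: dist.ProcessGroup):
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group)


def allreduce_hook(state: DefaultState, grad: torch.Tensor) -> None:
    # Non-sharded case: a single input tensor; every rank ends up with the
    # averaged full gradient in `grad`.
    grad.div_(state.world_size)
    dist.all_reduce(grad, group=state.process_group)


def reduce_scatter_hook(
    state: DefaultState, grad: torch.Tensor, output: torch.Tensor
) -> None:
    # Sharded case: two tensors. `grad` holds the full (flattened) gradient,
    # `output` is this rank's shard. The summed chunk belonging to this rank
    # is written into `output`, which is then used as the gradient shard.
    grad.div_(state.world_size)
    dist.reduce_scatter_tensor(output, grad, group=state.process_group)


def fp16_compress_hook(
    state: DefaultState, grad: torch.Tensor, output: Optional[torch.Tensor] = None
) -> None:
    # Low-precision variant: communicate in fp16, then dispatch to all_reduce
    # or reduce_scatter depending on whether an `output` tensor is provided.
    grad_fp16 = grad.to(torch.float16).div_(state.world_size)
    if output is None:
        dist.all_reduce(grad_fp16, group=state.process_group)
        grad.copy_(grad_fp16)
    else:
        out_fp16 = torch.empty_like(output, dtype=torch.float16)
        dist.reduce_scatter_tensor(out_fp16, grad_fp16, group=state.process_group)
        output.copy_(out_fp16)
```

Under this sketch, FSDP would call the two-tensor hook for sharded strategies and the single-tensor hook otherwise, which is why a single `DefaultState` suffices for both paths.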