pytorch
5fc209ed - FSDP communication hook interface for NO_SHARD strategy (#79833)

FSDP communication hook interface for NO_SHARD strategy (#79833)

Fixes #79114

An implementation of an FSDP communication hook interface for the NO_SHARD strategy:

- `FullyShardedDataParallel.register_comm_hook(self, state: object, hook: callable)` checks the current sharding strategy. If it is anything other than NO_SHARD, it raises a runtime error. Otherwise, it sets the specified hook and its state and shares them with all submodules.
- When FSDP is ready to communicate a gradient, it checks whether a hook is registered and, if so, calls it instead of `all_reduce`. Additionally, gradient pre- and post-division are not performed when a hook is registered.

To test the interface, I've implemented a communication hook that calls `all_reduce`. A unit test:

- checks that if the sharding strategy is anything but NO_SHARD, a runtime error is raised;
- checks that for the NO_SHARD case, a model with the registered all_reduce hook and a model without a hook behave the same;
- checks two types of FSDP models, with and without the first layer wrapped, to make sure submodules have the hook registered.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79833
Approved by: https://github.com/rohan-varma, https://github.com/awgu
Author: Olga Andreeva
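
A minimal sketch of how the interface described above might be used. The `AllReduceState` class and the body of `allreduce_hook` are illustrative assumptions rather than code from this PR; the `hook(state, grad)` call shape is inferred from the all_reduce test hook the commit describes, and the script assumes an already initialized process group.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy


class AllReduceState:
    """Hypothetical hook state: carries the process group to reduce over."""
    def __init__(self, process_group=None):
        self.process_group = process_group


def allreduce_hook(state: AllReduceState, grad: torch.Tensor) -> None:
    """Stand-in for FSDP's built-in all_reduce on the unsharded gradient.

    Because FSDP skips its own gradient pre-/post-division once a hook is
    registered, the hook averages the gradient across ranks itself.
    """
    group = state.process_group if state.process_group is not None else dist.group.WORLD
    dist.all_reduce(grad, group=group)
    grad.div_(dist.get_world_size(group))


# Assumes torch.distributed is initialized (e.g. launched via torchrun).
model = torch.nn.Linear(16, 16).cuda()
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)

# Per the commit, register_comm_hook raises a runtime error for any
# sharding strategy other than NO_SHARD.
fsdp_model.register_comm_hook(AllReduceState(), allreduce_hook)
```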