f7d68284 - Adding fsdp fp16 and bf16 hooks (#80557)

Recently, `register_comm_hook` was introduced to `FSDP`; at the moment it supports only the `NO_SHARD` strategy and ships with a default `all_reduce` hook. This PR adds two lower-precision hooks alongside the existing default hook.

I've also made slight adjustments to the existing implementation of the `all_reduce` hook:

- Renamed `AllReduceState` to `DefaultState`. Motivation: `AllReduceState` is not specific to `all_reduce`; the gradients' pre- and post-division factors are also useful for other hooks that require pre- and post-division, e.g. the fp16 and bf16 hooks.
- Moved all three hooks into `default_hooks.py`.

Additionally, `FSDP` supports `MixedPrecision`, so in principle it is possible to specify `MixedPrecision` for gradients and also attach a lower-precision hook to the model. To avoid double-casting, I've added a couple of checks to `fully_sharded_data_parallel` so that casting to the lower precision and back is performed by the lower-precision hook only. As a next step, it would be nice to ensure that a user can't have both a lower-precision hook and `MixedPrecision(reduce_dtype=<precision>)` specified, but I am happy to discuss this and adjust the current implementation.

As a test, I create two models, one with a lower-precision hook and one with `MixedPrecision(reduce_dtype=<precision>)` specified, perform one forward/backward pass and an optimizer step, and compare gradients.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80557
Approved by: https://github.com/rohan-varma
Author: Olga Andreeva