pytorch
f3e42f15 - [FSDP] Start to generalize modules to ignore for mixed precision (#102010)

[FSDP] Start to generalize modules to ignore for mixed precision (#102010)

The main use case here is that folks would like to ignore layer norm for mixed precision. This can now be enabled with:

```
mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
    _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
)
```

This works by wrapping modules whose classes appear in `_mixed_precision_module_classes_to_ignore` in their own FSDP units with mixed precision disabled. This is only enabled for auto wrapping. We also add module pre- and post-forward hooks that cast inputs for these modules up to full precision and back down afterward.

Differential Revision: [D46079957](https://our.internmc.facebook.com/intern/diff/D46079957/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102010
Approved by: https://github.com/awgu
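
Below is a minimal sketch of how this option might be used end to end with auto wrapping, which the commit notes is required for the ignore list to take effect. The process-group setup, device placement, model definition, and the `ModuleWrapPolicy` choice are illustrative assumptions; the `_mixed_precision_module_classes_to_ignore` keyword follows this commit and may be exposed under a different name in later releases.

```
# Sketch only: assumes torch.distributed is already initialized (e.g. via
# torchrun) and that the private keyword name matches this commit.
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

mp_config = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
    # Modules of these classes are wrapped into their own FSDP units with
    # mixed precision disabled, so they run in full precision.
    _mixed_precision_module_classes_to_ignore=[_BatchNorm, nn.LayerNorm],
)

model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.LayerNorm(1024),
    nn.Linear(1024, 1024),
).cuda()

# Auto wrapping is required for the ignore list to take effect; the
# ModuleWrapPolicy used here is only an illustrative choice.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
    mixed_precision=mp_config,
)
```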