fc429512 - [FSDP] Clean up `FlatParamHandle` dtypes, post-backward hook (#90660)

This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and makes some minor changes to the communication hooks.

**Overview**

This PR addresses everything in https://github.com/pytorch/pytorch/issues/90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`, since that is BC breaking. I recommend reading the issue before proceeding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the exact rule for parameter and gradient reduction mixed precision that we follow is:

> If `param_dtype is not None` and `reduce_dtype is None`, then we infer `reduce_dtype = param_dtype`. Otherwise, we take `param_dtype` and `reduce_dtype` as is.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check whether mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`; it is no longer correct to check against `None`. This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either silently relies on no-ops or errors (such as in one case reported by MosaicML).

**Additional Details**

- We remove `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads` since they are not used.
- The unit test `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta-device initialization, `apply()` (which calls into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case where each nested instance has parameters.
- We include some minor fixes/improvements to the communication hook implementation.

**Follow-Ups**

- Get rid of `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90660
Approved by: https://github.com/zhaojuanmao
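
---

The following is a minimal sketch, not the actual FSDP internals, of the dtype-resolution rule and the "compare against the original dtype instead of `None`" check described above. The helper names (`resolve_dtypes`, `uses_param_mixed_precision`) are hypothetical and introduced only for illustration.

```python
from typing import Optional, Tuple

import torch


def resolve_dtypes(
    orig_param_dtype: torch.dtype,
    param_dtype: Optional[torch.dtype],
    reduce_dtype: Optional[torch.dtype],
) -> Tuple[torch.dtype, torch.dtype]:
    """Return concrete (fwd_bwd_param_dtype, reduce_dtype); never ``None``."""
    # If param_dtype was not given, fall back to the original parameter dtype.
    resolved_param_dtype = param_dtype if param_dtype is not None else orig_param_dtype
    # Rule: if param_dtype is given and reduce_dtype is not, infer
    # reduce_dtype = param_dtype; otherwise fall back to the resolved
    # parameter dtype (i.e. the original dtype when neither was given).
    resolved_reduce_dtype = reduce_dtype if reduce_dtype is not None else resolved_param_dtype
    return resolved_param_dtype, resolved_reduce_dtype


def uses_param_mixed_precision(
    orig_param_dtype: torch.dtype, fwd_bwd_param_dtype: torch.dtype
) -> bool:
    # Mixed precision is enabled iff the low-precision dtype differs from the
    # original dtype, so MixedPrecision(param_dtype=torch.float32) applied to
    # fp32 parameters is correctly treated as disabled rather than a no-op
    # cast path.
    return fwd_bwd_param_dtype != orig_param_dtype


if __name__ == "__main__":
    orig = torch.float32

    # MixedPrecision(param_dtype=torch.float32): reduce_dtype is inferred as
    # fp32 and parameter mixed precision is considered disabled.
    p, r = resolve_dtypes(orig, torch.float32, None)
    assert (p, r) == (torch.float32, torch.float32)
    assert not uses_param_mixed_precision(orig, p)

    # MixedPrecision(param_dtype=torch.bfloat16): reduce_dtype is inferred as
    # bf16 and parameter mixed precision is enabled.
    p, r = resolve_dtypes(orig, torch.bfloat16, None)
    assert (p, r) == (torch.bfloat16, torch.bfloat16)
    assert uses_param_mixed_precision(orig, p)
```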