pytorch
b66cedd9 - [FSDP] Fix `use_orig_params=True` + `no_sync()` (#90546)

`no_sync()` introduces a separate case in which a `FlatParameter` maintains an _unsharded_ gradient instead of a _sharded_ one. This PR fixes `no_sync()` with `use_orig_params=True` by handling that case.

The existing `use_orig_params=False` path already bypasses the built-in parameter/gradient size check, which would otherwise fail because `flat_param` is sharded while `flat_param.grad` is unsharded. For `use_orig_params=True`, we use the same `.data` hack to sidestep the size check that is already used to sidestep the dtype check for `keep_low_precision_grads=True`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90546
Approved by: https://github.com/rohan-varma
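For illustration, here is a minimal sketch of the `.data` hack the message refers to; it is not FSDP's internal code, and `param` / `unsharded_grad` are hypothetical stand-ins for the sharded `flat_param` and the unsharded gradient kept under `no_sync()`. Assigning `param.grad` directly enforces a size match, while swapping the gradient's storage through `.data` bypasses that check.

```python
import torch

# Hypothetical stand-ins: `param` plays the role of the sharded `flat_param`,
# and `unsharded_grad` the unsharded gradient that `no_sync()` accumulates.
param = torch.nn.Parameter(torch.zeros(4))
unsharded_grad = torch.ones(8)

# Direct assignment trips autograd's built-in size check because the
# gradient's shape does not match the (sharded) parameter's shape.
try:
    param.grad = unsharded_grad
except RuntimeError as exc:
    print(f"size check rejected the grad: {exc}")

# The `.data` hack: set a correctly sized gradient first, then swap its
# underlying data. Assigning through `.data` skips the size/dtype checks,
# so the sharded parameter can temporarily hold an unsharded gradient.
param.grad = torch.empty_like(param)
param.grad.data = unsharded_grad
print(param.grad.shape)  # torch.Size([8]), while param.shape is torch.Size([4])
```

The same pattern sidesteps the dtype check for `keep_low_precision_grads=True`: once the gradient's storage is swapped via `.data`, later reduce-scatter or resharding logic is responsible for restoring a gradient that matches the parameter again.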