pytorch
688b7672 - [FSDP] Fix `keep_low_precision_grads=True` for `use_orig_params=True` (#90027)

For any `flat_param.data = flat_param.to(...)` or `flat_param.grad.data = flat_param.grad.to(...)`, we must also refresh the sharded parameter or gradient views, respectively, if the storage changes.

For `keep_low_precision_grads=True` with a sharded strategy, we cast the gradient back to low precision using `.data` to bypass the PyTorch check that a parameter and its gradient have the same dtype. For `use_orig_params=True` before this PR, the gradient would incorrectly remain in full precision rather than low precision because we did not refresh the views (this can also be considered a memory leak, since two copies of the gradient stay alive: one in low precision and one in full precision). This PR refreshes the views.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90027
Approved by: https://github.com/mrshenli
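Below is a minimal, hypothetical sketch (not FSDP's actual internals) of the failure mode described above: reassigning `.grad.data` to a low-precision cast allocates new storage, so a view created on the old storage becomes stale until it is refreshed. The names `flat_param` and `grad_view` here are stand-ins for illustration only.

```python
import torch

flat_param = torch.nn.Parameter(torch.ones(8))   # stand-in for FSDP's flat parameter
flat_param.grad = torch.full((8,), 2.0)          # full-precision gradient

# View into the gradient, analogous to a per-parameter sharded gradient view.
grad_view = flat_param.grad[:4]

# Cast the gradient back to low precision via `.data`, bypassing the check that a
# parameter and its gradient share the same dtype (as the commit message describes).
flat_param.grad.data = flat_param.grad.to(torch.float16)

print(flat_param.grad.dtype)   # torch.float16
print(grad_view.dtype)         # torch.float32 -- stale view on the old storage
# Two separate copies of the gradient are now alive (the "memory leak" noted above):
print(grad_view.data_ptr() == flat_param.grad.data_ptr())  # False

# The fix corresponds to refreshing the view so it aliases the new low-precision storage:
grad_view = flat_param.grad[:4]
print(grad_view.dtype)         # torch.float16
```

Under this sketch, refreshing the view both restores the intended low-precision gradient and drops the last reference to the stale full-precision copy.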