pytorch
607eccb1 - [FSDP] Option to keep grads in lower prec (#85134)

Committed 2 years ago
Differential Revision: [D39565189](https://our.internmc.facebook.com/intern/diff/D39565189)

Rehash of a similar PR from a month ago that got stale.

Adds a config to FSDP mixed precision so that gradients can be kept in lower precision, to support optimizers such as AnyPrecisionOptimizer, which want to keep grads in bf16.

To do this for sharded cases, we cannot simply omit the cast back to the full-precision param dtype: when setting `p.grad = p._saved_grad_shard` in finalize_params, autograd throws an error indicating that the grad dtype must match the param dtype when it is assigned. As a workaround, we re-cast after setting it. However, this means that for cases using gradient accumulation, `p._saved_grad_shard` will be of the reduced dtype, because it is set to `p.grad` in `_prep_grad_for_backward`. As a result, a check and re-cast are added there as well.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85134
Approved by: https://github.com/awgu
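For context, here is a minimal sketch of how such a policy is configured through FSDP's `MixedPrecision` dataclass. The flag name `keep_low_precision_grads` reflects the current FSDP API and may differ from the exact name introduced in this PR; `MyModel` and the single-process distributed setup are illustrative assumptions, not part of the commit.

```python
# Minimal sketch: keep FSDP gradients in bf16 after backward.
# Assumes a process group is already initialized (e.g. via torchrun) and CUDA is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


class MyModel(nn.Module):  # hypothetical stand-in model
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(x)


mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,     # run forward/backward in bf16
    reduce_dtype=torch.bfloat16,    # reduce gradients in bf16
    keep_low_precision_grads=True,  # leave p.grad in bf16 instead of casting back to fp32
)

model = FSDP(MyModel().cuda(), mixed_precision=mp_policy)
# After backward, the sharded gradients remain in bf16, which is what optimizers
# such as AnyPrecisionOptimizer expect to consume.
```

The re-cast workaround described above exists because PyTorch's grad setter requires an assigned gradient's dtype to match the parameter's dtype, so a bf16 `p._saved_grad_shard` cannot be assigned directly to an fp32 parameter's `.grad`.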