d04e3c99 - [FSDP] Fix input grad propagation when using param mixed precision (#90921)

[FSDP] Fix input grad propagation when using param mixed precision (#90921)

For parameter mixed precision, we cast the inputs to the low-precision parameter dtype. If the input has tensors that require gradient, then we must cast them in place in order for them to receive a gradient. The cast should be tracked by autograd (e.g., with `grad_fn` equal to `ToCopyBackward0`). This removes the `torch.no_grad` context when calling `_apply_to_tensors`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90921
Approved by: https://github.com/mrshenli, https://github.com/rohan-varma
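A minimal sketch (not FSDP's actual `_apply_to_tensors` code) illustrating why the dtype cast must not run under `torch.no_grad`: only when autograd records the cast (as `ToCopyBackward0`) can gradients flow back to the original full-precision input.

```python
import torch

x = torch.randn(4, requires_grad=True, dtype=torch.float32)

# Cast under no_grad: the copy is detached, so no gradient reaches x.
with torch.no_grad():
    x_detached = x.to(torch.float16)
print(x_detached.grad_fn)  # None -- autograd did not track the cast

# Cast with autograd enabled: the copy records ToCopyBackward0.
x_tracked = x.to(torch.float16)
print(x_tracked.grad_fn)   # <ToCopyBackward0 ...>

x_tracked.sum().backward()
print(x.grad)              # gradient propagates back to the fp32 input
```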