pytorch
4ee13a59 - [FSDP][1/N] Update `summon_full_params(with_grads)` `None` gradient (#87314)

2 years ago
This PR changes `summon_full_params(with_grads=True)`'s behavior so that if all ranks have `flat_param.grad = None`, then the original parameters correctly have `orig_param.grad = None`. This is achieved with a preliminary all-reduce.

Note that if a particular original parameter's gradient is `None` on all of its containing ranks, but not all ranks have `flat_param.grad = None`, then that gradient is still set to zeros. This can be handled in follow-up work if desired.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87314
Approved by: https://github.com/zhaojuanmao
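The decision rule the commit describes can be sketched as follows. This is a plain-Python simulation of the preliminary all-reduce, not FSDP's actual internals; the function name and return values are illustrative assumptions. Each rank contributes a flag (1 if its `flat_param.grad` is `None`), and summing the flags emulates an `all_reduce(SUM)` across ranks:

```python
def grads_none_decision(ranks_flat_grad_is_none):
    """Simulate the preliminary all-reduce over ranks.

    ranks_flat_grad_is_none: list of bools, one per rank, True if that
    rank's flat_param.grad is None. Summing the flags stands in for an
    all-reduce(SUM) of a 0/1 tensor. (Hypothetical helper, not FSDP code.)
    """
    world_size = len(ranks_flat_grad_is_none)
    num_none = sum(1 for is_none in ranks_flat_grad_is_none if is_none)
    # Only when every rank's flat_param.grad is None do the original
    # parameters get orig_param.grad = None; otherwise gradients are
    # materialized and missing shards are filled with zeros.
    if num_none == world_size:
        return "grad_is_None"
    return "zeros_for_missing_shards"

# All ranks lack gradients -> original parameters see grad = None.
print(grads_none_decision([True, True, True]))
# One rank has a gradient -> None shards are still set to zeros,
# the edge case the commit message defers to follow-up work.
print(grads_none_decision([True, False, True]))
```

This mirrors why a communication step is needed at all: no single rank can know locally whether *every* rank's `flat_param.grad` is `None`.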