1133682c - [FSDP][2/N] Fix grad zero vs. `None` edge case (#87308)

Some original parameters corresponding to one `FlatParameter` may have a `None` gradient while others do not. In that case, `flat_param.grad` must be non-`None`. However, FSDP should take care to expose the original parameters' gradients regardless. To achieve this, we track a `_is_grad_none` mask over the parameters' gradients.

- `_is_grad_none` is initialized to `False` for all parameters.
- `_is_grad_none[i]` is set to `True` when writing zeros in place of `None` when writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87308
Approved by: https://github.com/zhaojuanmao
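The masking scheme above can be sketched in simplified form. This is a hypothetical standalone class (`FlatGradTracker`, not FSDP's actual `FlatParamHandle`) illustrating the mask semantics only: zeros are written into the shared flat gradient for `None` gradients, the mask records that fact, and a view is exposed only when the mask is `False`.

```python
import torch

class FlatGradTracker:
    """Hypothetical sketch of the `_is_grad_none` mask over original
    parameters that share one flat gradient buffer (not FSDP's real class)."""

    def __init__(self, numels):
        self.numels = numels
        # Mask is initialized to False for all parameters.
        self._is_grad_none = [False] * len(numels)
        self.flat_grad = torch.zeros(sum(numels))

    def _slice(self, i):
        offset = sum(self.numels[:i])
        return self.flat_grad[offset:offset + self.numels[i]]

    def write_back(self, i, grad):
        # Writing back the i-th gradient: write zeros in place of `None`
        # and record that in the mask.
        if grad is None:
            self._slice(i).zero_()
            self._is_grad_none[i] = True
        else:
            self._slice(i).copy_(grad)
            self._is_grad_none[i] = False

    def grad_view(self, i):
        # `param.grad` may be set to a view into `flat_param.grad` only
        # when the mask entry is False; otherwise expose None.
        if self._is_grad_none[i]:
            return None
        return self._slice(i)

    def reset_is_grad_none(self):
        # Analogous to `_reset_is_grad_none()`, called in post-backward.
        self._is_grad_none = [False] * len(self.numels)
```

A usage example: after writing back one real gradient and one `None` gradient, the flat gradient is non-`None` as a whole, yet the second original parameter still sees `None` until the mask is reset.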