[FSDP][2/N] Fix grad zero vs. `None` edge case (#87308)
Some original parameters corresponding to one `FlatParameter` may have a `None` gradient while others do not. In that case, `flat_param.grad` must be non-`None`, with zeros filling the slots of the `None` gradients. However, FSDP should still expose each original parameter's own gradient, so a parameter whose gradient was `None` should not be given a view into those zeros. To achieve this, we track an `_is_grad_none` mask over the parameters' gradients.
- `_is_grad_none` is initialized to `False` for all original parameters.
- `_is_grad_none[i]` is set to `True` when zeros are written in place of `None` while writing back the `i`th gradient.
- `_is_grad_none[i]` is set to `False` via `_reset_is_grad_none()`, which should be called in the post-backward. See the docstring for details.
- `_is_grad_none[i]` must be `False` in order to set `param.grad` to be a view into `flat_param.grad`.
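
The mask lifecycle above can be sketched with a toy model. This is only an illustration in plain Python, not the FSDP implementation: `ToyFlatParam`, `write_back_grads()`, and `expose_grads()` are hypothetical names, gradients are plain lists instead of tensors, and list slices are copies rather than views. Only `_is_grad_none` and `_reset_is_grad_none()` are names taken from this change.

```python
from typing import List, Optional


class ToyFlatParam:
    """Hypothetical stand-in for a flat parameter made of several originals."""

    def __init__(self, sizes: List[int]) -> None:
        self.sizes = list(sizes)
        self.grad: Optional[List[float]] = None
        # Initialized to False for all original parameters.
        self._is_grad_none = [False] * len(self.sizes)

    def write_back_grads(self, grads: List[Optional[List[float]]]) -> None:
        """Write per-parameter gradients into the flat gradient."""
        if all(g is None for g in grads):
            self.grad = None
            return
        flat: List[float] = []
        for i, (g, n) in enumerate(zip(grads, self.sizes)):
            if g is None:
                # Zeros stand in for None; remember that via the mask.
                flat.extend([0.0] * n)
                self._is_grad_none[i] = True
            else:
                assert len(g) == n
                flat.extend(g)
        self.grad = flat

    def expose_grads(self) -> List[Optional[List[float]]]:
        """Return each original parameter's gradient, honoring the mask."""
        if self.grad is None:
            return [None] * len(self.sizes)
        out: List[Optional[List[float]]] = []
        offset = 0
        for i, n in enumerate(self.sizes):
            if self._is_grad_none[i]:
                out.append(None)  # masked: do not expose the zeros
            else:
                out.append(self.grad[offset:offset + n])  # copy, not a view
            offset += n
        return out

    def _reset_is_grad_none(self) -> None:
        """Clear the mask (in this sketch, called after backward)."""
        self._is_grad_none = [False] * len(self.sizes)


fp = ToyFlatParam([2, 3])
fp.write_back_grads([[1.0, 2.0], None])
exposed_before_reset = fp.expose_grads()
fp._reset_is_grad_none()
exposed_after_reset = fp.expose_grads()
```

In this sketch the second parameter is exposed as `None` while its mask entry is `True`, and its slice of the flat gradient becomes visible again only after the mask is reset.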
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87308
Approved by: https://github.com/zhaojuanmao