Fix FSDP not all outputs used in loss (#83195)
There are a couple of issues / assumptions within FSDP today that this PR attempts to fix:
- In `_wait_for_post_backward`, we assume that if a parameter requires grad, its post-backward hook was called. This is not true: if the parameter's output did not participate in the gradient computation, the hook never fires. To fix this, we simply remove those assertions.
- There is a deeper issue where, in `_finalize_params`, we could end up assigning a gradient of the sharded shape to an unsharded parameter's gradient field, which would raise a shape error. This can happen, for example, if a parameter's usage transitions from used -> unused. When the parameter was used, it would have had a gradient; the user could then have called `zero_grad()`, so `p.grad` would not be `None`. Then, in `_prep_grad_for_backward`, we would stash this sharded-shape gradient as `_saved_grad_shard`. In `_finalize_params`, the parameter would still be unsharded (since its post-backward hook was not called), but we would try to assign `_saved_grad_shard` back to `p.grad` anyway, raising the shape error. This issue is fixed by checking `_post_backward_called`: if it is `False`, we simply skip the assignment because there is no new gradient to write back (see the sketch after this list).
- A final issue, related to the above, is that if the post-backward hook is not called, we never reshard the full parameter. This is fixed by checking whether we have not yet resharded (essentially, whether `_post_backward_called` is `False`) and, if so, performing the reshard (also sketched after this list).
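
To make the scenario concrete, below is a minimal, hypothetical repro sketch (not a test added by this PR). It assumes `torch.distributed` has already been initialized (e.g. via `torchrun` with a NCCL backend) and that a GPU is available; `head_b` is wrapped as its own FSDP unit, so dropping its output from the loss means its post-backward hook never fires:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TwoHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 8)
        self.head_a = nn.Linear(8, 8)
        self.head_b = nn.Linear(8, 8)

    def forward(self, x):
        h = self.trunk(x)
        return self.head_a(h), self.head_b(h)

model = TwoHeadModel().cuda()
# Wrap head_b as its own FSDP unit so it owns a separate FlatParameter.
model.head_b = FSDP(model.head_b)
model = FSDP(model)
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(2):
    out_a, out_b = model(torch.randn(4, 8, device="cuda"))
    if step == 0:
        # head_b's output participates in the loss, so it gets a gradient.
        loss = out_a.sum() + out_b.sum()
    else:
        # head_b's output is dropped from the loss: its post-backward hook
        # never fires, which previously tripped the assertions and the
        # shape mismatch described above.
        loss = out_a.sum()
    loss.backward()
    optim.step()
    # Zero (rather than None-out) the grads so a stale sharded gradient
    # sticks around, matching the used -> unused transition described above.
    optim.zero_grad(set_to_none=False)
```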
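
Roughly, the shape of the finalize-time fixes from the last two bullets is sketched below. This is illustrative pseudocode rather than the literal code in this PR; `reshard_full_param` is a hypothetical stand-in for FSDP's internal resharding step, and the attribute names follow the descriptions above:

```python
def reshard_full_param(p) -> None:
    # Hypothetical placeholder for FSDP's internal "free the unsharded data
    # and point the parameter back at its local shard" step.
    ...

def finalize_param(p) -> None:
    if not p.requires_grad:
        return
    if not getattr(p, "_post_backward_called", False):
        # The post-backward hook never fired (the parameter's output did not
        # participate in the loss), so there is no new gradient to write back.
        # Assigning the stale, sharded _saved_grad_shard to the still-unsharded
        # parameter would raise a shape mismatch, so skip the assignment.
        # The hook is also where resharding normally happens, so reshard here
        # to avoid leaving the full parameter materialized.
        reshard_full_param(p)
        return
    if hasattr(p, "_saved_grad_shard"):
        # Normal path: post-backward ran and the parameter is sharded again,
        # so the saved sharded gradient matches p.shape.
        p.grad = p._saved_grad_shard
        delattr(p, "_saved_grad_shard")
```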
A few things to note:
- This logic may have to be revisited when non-recursive wrapping lands, as there will be multiple FlatParams per FSDP unit.
- This logic may not work when the post-backward hook fires but `p.grad` is `None`, i.e. the short-circuit here: https://github.com/pytorch/pytorch/blob/f534b2c627da65bbee7ccc8f7e054da0ba48eb79/torch/distributed/fsdp/fully_sharded_data_parallel.py#L2884. As a quick fix, we could move the `_post_backward_called` flag change to after this check, or perform a reshard before returning early (both options are sketched below). I am not sure how to reproduce a case where `p.grad` is `None` but the post-backward hook is still called; https://github.com/pytorch/pytorch/issues/83197 might be a possibility, but I think it is fine not to support this yet.
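
For reference, the two quick-fix options would look roughly like the following (illustrative only; neither is part of this PR, and `reshard_full_param` / `reduce_and_shard_grad` are hypothetical stand-ins for FSDP's internal reshard and gradient reduce steps):

```python
def reshard_full_param(p) -> None:
    # Hypothetical placeholder for FSDP's internal resharding step.
    ...

def reduce_and_shard_grad(p) -> None:
    # Hypothetical placeholder for the gradient reduce-scatter / shard step.
    ...

def post_backward_hook(p) -> None:
    if p.grad is None:
        # Option A: reshard before the early return so the full parameter is
        # still freed even though there is no gradient to reduce.
        reshard_full_param(p)
        return
    # Option B: only mark the hook as having run once we know we will take
    # the full reduce/reshard path, so _finalize_params handles the
    # short-circuited case the same as the never-called case.
    p._post_backward_called = True
    reduce_and_shard_grad(p)
```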
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83195
Approved by: https://github.com/awgu