Fix FSDP when not all outputs get gradient in backward (#80245)
In some use cases, FSDP runs into an issue where a training-state assert in `_wait_for_post_backward` fires erroneously. The root cause is that `_post_backward_hook`, which sets the module's training state to `BACKWARD_POST`, is never actually called, because no parameter in that module had a gradient computed for it. As with DDP, this can happen when not all module outputs are used in the loss computation, or when the module did not participate in the forward pass at all.
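A minimal sketch of the failure mode (illustrative only, not from the PR; `TwoHeadModel` is a hypothetical module, and the script assumes a distributed process group has already been initialized, e.g. via `torchrun`). Each head is wrapped in its own FSDP instance; only one head feeds the loss, so the other head's parameters receive no gradient and its post-backward hook never fires:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TwoHeadModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Nested FSDP wrapping: each head has its own post-backward hook.
        self.used_head = FSDP(nn.Linear(8, 8))
        self.unused_head = FSDP(nn.Linear(8, 8))

    def forward(self, x):
        return self.used_head(x), self.unused_head(x)

model = FSDP(TwoHeadModel())
used_out, _unused_out = model(torch.randn(4, 8))
loss = used_out.sum()  # unused_head contributes nothing to the loss
loss.backward()        # previously tripped the assert in _wait_for_post_backward
```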
Fix this by tracking a flag, `_post_backward_called`, that records whether the hook actually ran, so that `_wait_for_post_backward` can account for modules whose hook never fired.
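A simplified sketch of the fix pattern follows. The class, the `params` attribute, `_assert_state`, and the standalone `TrainingState` enum are illustrative stand-ins for FSDP internals (the real implementation tracks this bookkeeping on its flattened parameters): the flag is cleared before each backward pass, set when the hook fires, and consulted before asserting on the training state.

```python
import enum

class TrainingState(enum.Enum):
    # Stand-in for FSDP's internal training-state enum.
    BACKWARD_PRE = enum.auto()
    BACKWARD_POST = enum.auto()

class _FSDPLike:
    def _prep_grads_for_backward(self):
        # Reset the flag at the start of each backward pass.
        for p in self.params:
            p._post_backward_called = False

    def _post_backward_hook(self, param, *unused):
        # Record that the hook really fired for this parameter.
        param._post_backward_called = True
        self.training_state = TrainingState.BACKWARD_POST
        # ... gradient reduction etc. ...

    def _wait_for_post_backward(self):
        if any(p._post_backward_called for p in self.params):
            # The hook ran, so the state must have advanced.
            self._assert_state(TrainingState.BACKWARD_POST)
        else:
            # No parameter here received a gradient, so the hook never
            # ran; the pre-backward state is the expected one.
            self._assert_state(TrainingState.BACKWARD_PRE)
```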
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80245
Approved by: https://github.com/awgu