[FSDP] Fix a small bug of pre_backward_hook params prefetch (#78851)
Fixes a potential small bug in the FSDP `_pre_backward_hook` params prefetch.
In `_pre_backward_hook`, `self._need_prefetch_full_params(self.training_state)` is used to decide whether the params for the next backward pass need to be prefetched, and currently it is also used to check whether synchronization should be performed in the current backward pass before `_rebuild_full_params`.
For some edge cases, using this check to decide whether to synchronize is not correct. One example is when `self._my_fsdp_idx_in_graph == 0`, meaning this layer runs the last backward pass. In that case `self._need_prefetch_full_params(self.training_state)` returns `False`, since there is no backward pass after it, so synchronization is currently skipped before `_rebuild_full_params`.
But this layer's params were prefetched during the previous layer's backward pass, so a synchronization still needs to be done before they are used.
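A simplified sketch of the problematic control flow (the stream handling and helper names below are illustrative, not the verbatim FSDP source):

```python
def _pre_backward_hook(self):
    # One predicate gates two different things: waiting on a prefetch that
    # the *previous* layer may have issued for this layer, and issuing a
    # prefetch for the *next* layer in the backward graph.
    if self._need_prefetch_full_params(self.training_state):
        # Wait for the prefetch all-gather of this layer's params.
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
    self._rebuild_full_params()
    if self._need_prefetch_full_params(self.training_state):
        # Prefetch the params of the next layer to run backward.
        self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1]._rebuild_full_params()

# When self._my_fsdp_idx_in_graph == 0 the predicate is False, so the wait
# above is skipped even though the previous layer already prefetched this
# layer's params.
```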
To fix this, we just need to decide whether to synchronize based on a separate flag rather than on `self._need_prefetch_full_params(self.training_state)`, which is what this PR does.
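A minimal sketch of the fix, assuming a hypothetical `_prefetched` flag that is set on a layer whenever its params are prefetched:

```python
def _pre_backward_hook(self):
    if self._prefetched:
        # This layer's params were prefetched by the previous backward pass,
        # so always wait for that all-gather, even when this layer will not
        # issue any further prefetch (e.g. self._my_fsdp_idx_in_graph == 0).
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        self._prefetched = False
    self._rebuild_full_params()
    if self._need_prefetch_full_params(self.training_state):
        next_fsdp = self._fsdp_graph_order[self._my_fsdp_idx_in_graph - 1]
        next_fsdp._prefetched = True  # hypothetical flag; record the prefetch
        next_fsdp._rebuild_full_params()
```

This decouples "did someone prefetch my params" from "should I prefetch for the next layer", so the last layer in backward order still waits on its prefetched params.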
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78851
Approved by: https://github.com/zhaojuanmao