[FSDP] Stricten `_update_p_data()` in `_summon_full_params()` (#81573)
https://github.com/pytorch/pytorch/blob/385ae8721e327a12be5120115c01f866fcdefa71/torch/distributed/fsdp/fully_sharded_data_parallel.py#L2530-L2542
The `finally` block below should undo what is done above -- namely, pointing the flattened parameter's data to the CPU copy of the unsharded flattened parameter.
https://github.com/pytorch/pytorch/blob/385ae8721e327a12be5120115c01f866fcdefa71/torch/distributed/fsdp/fully_sharded_data_parallel.py#L2558-L2575
(This code snipped shows after the change in the PR.)
This PR makes the conditional in the `finally` match the conditional before (adding the `and (not rank0_only or my_rank == 0)` part). Otherwise, for nonzero ranks when `rank0_only == True`, their flattened parameters' `.data` is unnecessarily updated to itself.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81573
Approved by: https://github.com/rohan-varma