[FSDP][optim_state_dict] Fuse allgather for optim_state_dict when use_orig_params is True (#108298)
The original implementation of `_gather_orig_param_state` is naive. It performs one allgather_object and two allgather (if the optimizer is Adam) per FQN. This can be slow and make `_optim_state_dict` become bottleneck.
This PR rewrite the implementation and fuse all the `allgather_object`s into one. As for `allgather`, it is fused based on the information of FlatParameters. So there will be 2N `allgather` where N is the number of FlatParameter and 2 is due to Adam having 2 states per FQN.
One experiment on 8GPU A100 shows that the execution of the gathering is improved to 0.3 seconds from 3 seconds.
Differential Revision: [D48835138](https://our.internmc.facebook.com/intern/diff/D48835138/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108298
Approved by: https://github.com/awgu