pytorch
cdde899a - [FSDP][optim_state_dict] Fuse allgather for optim_state_dict when use_orig_params is True (#108298)

Commit
1 year ago
[FSDP][optim_state_dict] Fuse allgather for optim_state_dict when use_orig_params is True (#108298) The original implementation of `_gather_orig_param_state` is naive. It performs one allgather_object and two allgather (if the optimizer is Adam) per FQN. This can be slow and make `_optim_state_dict` become bottleneck. This PR rewrite the implementation and fuse all the `allgather_object`s into one. As for `allgather`, it is fused based on the information of FlatParameters. So there will be 2N `allgather` where N is the number of FlatParameter and 2 is due to Adam having 2 states per FQN. One experiment on 8GPU A100 shows that the execution of the gathering is improved to 0.3 seconds from 3 seconds. Differential Revision: [D48835138](https://our.internmc.facebook.com/intern/diff/D48835138/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/108298 Approved by: https://github.com/awgu
Author
Committer
Parents
Loading