pytorch
80dfc974 - [2D] Enable 2D FSDP+TP model.load_state_dict() (#110925)

Commit
1 year ago
[2D] Enable 2D FSDP+TP model.load_state_dict() (#110925) This PR adds a all_gather_dtensor() method to fsdp/_fsdp_extensions.py and the actual implementation in tensor/parallel/fsdp.py. This enables FSDP to load 2D DTensor state_dict into model when calling `model.load_state_dict()`. cc. @fegin Pull Request resolved: https://github.com/pytorch/pytorch/pull/110925 Approved by: https://github.com/fegin ghstack dependencies: #110831, #110846
Author
Committer
Parents
Loading