pytorch
6c136c33 - [2D] Enable 2D DTensor state_dict for FSDP + TP (#110846)

Commit
1 year ago
[2D] Enable 2D DTensor state_dict for FSDP + TP (#110846) This PR adds a `chunk_dtensor()` method to fsdp/_fsdp_extensions.py and the actual implementation of `chunk_dtensor()` in tensor/parallel/fsdp.py. This enables FSDP to return 2D DTensor state_dict when composing FSDP with TP. cc. @fegin Pull Request resolved: https://github.com/pytorch/pytorch/pull/110846 Approved by: https://github.com/fegin, https://github.com/wanchaol ghstack dependencies: #110831
Author
Committer
Parents
Loading