793a999c - Hybrid Sharded Data Parallel (#89915)

Adds two new hybrid sharding strategies to FSDP:

1. HYBRID_SHARD: applies ZeRO-3-style sharding within a node and data parallelism across nodes.
2. HYBRID_SHARD_ZERO2: applies ZeRO-2-style sharding within a node and data parallelism across nodes.

These are useful for medium-sized models and aim to decrease communication volume; tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding works by sharding the model with a process group spanning a single node, and creating inter-node process groups for replication / data parallelism. The user either passes in a tuple of these process groups, or None, in which case we generate the process groups appropriately.

Acknowledgements:
- @awgu's excellent prototype: https://github.com/awgu/pytorch/commit/5ad3a16d486484c9ab4445126e50655eb19d62ca
- @liangluofb for ideation, feedback, and initial implementation and experimentation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89915
Approved by: https://github.com/awgu
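For illustration, here is a minimal sketch of enabling HYBRID_SHARD through the public FSDP wrapper, assuming a multi-node launch via torchrun; the placeholder model and launch flags are assumptions, not part of this commit. Passing no explicit process groups exercises the default path described above, where FSDP generates the intra-node sharding group and inter-node replication group itself:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# A minimal sketch, assuming a multi-node launch such as:
#   torchrun --nnodes=2 --nproc_per_node=8 train.py
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Transformer(d_model=512, nhead=8).cuda()  # placeholder model

# With no explicit process groups passed, FSDP derives the intra-node
# sharding group and the inter-node replication group from the topology.
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_id=local_rank,
)
```

Alternatively, per the commit description, a user can pass a tuple of process groups (the intra-node sharding group and the inter-node replication group) instead of None to control the topology explicitly.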