pytorch
272cf29e - [FSDP2][BE] Refactored `check_1d_sharded_parity` to use mesh (#121357)

Commit
315 days ago
[FSDP2][BE] Refactored `check_1d_sharded_parity` to use mesh (#121357) Eventually, we should just have one unified way to check for parity between a `DTensor`-sharded model and a replicated model. This PR is a small refactor to work toward that. One current gap to use this `check_sharded_parity` function for 2D is that FSDP's `(Shard(0), Shard(0))` layout differs from that of the `DTensor` APIs since FSDP shards on dim-0 after TP shards on dim-0. Pull Request resolved: https://github.com/pytorch/pytorch/pull/121357 Approved by: https://github.com/weifengpy ghstack dependencies: #121360
Author
Committer
Parents
Loading