[FSDP2][BE] Refactored `check_1d_sharded_parity` to use mesh (#121357)
Eventually, we should just have one unified way to check for parity between a `DTensor`-sharded model and a replicated model. This PR is a small refactor to work toward that. One current gap to use this `check_sharded_parity` function for 2D is that FSDP's `(Shard(0), Shard(0))` layout differs from that of the `DTensor` APIs since FSDP shards on dim-0 after TP shards on dim-0.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121357
Approved by: https://github.com/weifengpy
ghstack dependencies: #121360