DeepSpeed
0ecb987d - Add numerical-correctness test for Muon under ZeRO-1/2 (#8091)

Commit
5 days ago
Add numerical-correctness test for Muon under ZeRO-1/2 (#8091) ## Summary Adds a numerical-correctness regression test for the Muon optimizer under ZeRO-1/2. The existing Muon tests only assert that parameters changed, which cannot detect a wrong-but-nonzero update — exactly the failure mode of #7807, where `reduce_scatter` fed Newton-Schulz orthogonalization a partition slice instead of the full DP-averaged gradient. This complements the guard in #8090 by verifying the supported `reduce_scatter: false` path is actually numerically correct. ## What the test does `TestMuonZero12NumericalCorrectness` (in `tests/unit/ops/muon/test_muon.py`), `world_size=2`, parametrized over ZeRO stage `[1, 2]` and `ns_method ['gram', 'standard']`: 1. Builds a model sized so a 2-D weight's flattened gradient straddles the rank-0/rank-1 partition boundary, and asserts this from the **actual** flattened ZeRO partition (`optimizer.bit16_groups` / `bit16_groups_flat`, accounting for alignment padding) — the exact case #7807 corrupts. 2. Runs one step on the supported `reduce_scatter: false` path with `gradient_clipping=0` and `loss_scale=1`, so the applied master-weight update is exactly `-lr * muon_update(grad)`. 3. Compares that applied update against an independent reference that applies the real `muon_update` to the full DP-averaged gradient (using the library function, so Newton-Schulz rounding cancels), via relative Frobenius error. A correct update differs from the reference by only a few percent; the partition-then-orthogonalize bug diverges by O(1) on the cross-partition weight, so the assertion uses a 0.40 threshold that cleanly separates the two. ## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2) Relative error of the applied Muon update vs the full-gradient reference: | ns_method | correct path (max over weights) | buggy path (cross-partition weight) | |---|---|---| | gram (default) | 0.068 | 0.603 | | standard | 0.216 | 0.673 | The cross-partition weight is the only one affected; wholly-owned weights are identical on both paths. With `reduce_scatter: false` the test passes for both stages and both ns_methods; injecting the bug (`reduce_scatter: true`, pre-guard) makes the cross-partition assertion fail by a wide margin — i.e. this test would have caught #7807. ## Notes Follow-up to #8090 (which adds the guard and closes #7807). Kept in the existing `test_muon.py`. Requires >=2 GPUs (fp16). Refs #7807 cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana Signed-off-by: whycoming <alwaysxd666@gmail.com>
Author
Parents
Loading