Add numerical-correctness test for Muon under ZeRO-1/2 (#8091)
## Summary
Adds a numerical-correctness regression test for the Muon optimizer
under ZeRO-1/2. The existing Muon tests only assert that parameters
changed, which cannot detect a wrong-but-nonzero update — exactly the
failure mode of #7807, where `reduce_scatter` fed Newton-Schulz
orthogonalization a partition slice instead of the full DP-averaged
gradient. This complements the guard in #8090 by verifying the supported
`reduce_scatter: false` path is actually numerically correct.
## What the test does
`TestMuonZero12NumericalCorrectness` (in
`tests/unit/ops/muon/test_muon.py`), `world_size=2`, parametrized over
ZeRO stage `[1, 2]` and `ns_method ['gram', 'standard']`:
1. Builds a model sized so a 2-D weight's flattened gradient straddles
the rank-0/rank-1 partition boundary, and asserts this from the
**actual** flattened ZeRO partition (`optimizer.bit16_groups` /
`bit16_groups_flat`, accounting for alignment padding) — the exact case
#7807 corrupts.
2. Runs one step on the supported `reduce_scatter: false` path with
`gradient_clipping=0` and `loss_scale=1`, so the applied master-weight
update is exactly `-lr * muon_update(grad)`.
3. Compares that applied update against an independent reference that
applies the real `muon_update` to the full DP-averaged gradient (using
the library function, so Newton-Schulz rounding cancels), via relative
Frobenius error.
A correct update differs from the reference by only a few percent; the
partition-then-orthogonalize bug diverges by O(1) on the cross-partition
weight, so the assertion uses a 0.40 threshold that cleanly separates
the two.
## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2)
Relative error of the applied Muon update vs the full-gradient
reference:
| ns_method | correct path (max over weights) | buggy path
(cross-partition weight) |
|---|---|---|
| gram (default) | 0.068 | 0.603 |
| standard | 0.216 | 0.673 |
The cross-partition weight is the only one affected; wholly-owned
weights are identical on both paths. With `reduce_scatter: false` the
test passes for both stages and both ns_methods; injecting the bug
(`reduce_scatter: true`, pre-guard) makes the cross-partition assertion
fail by a wide margin — i.e. this test would have caught #7807.
## Notes
Follow-up to #8090 (which adds the guard and closes #7807). Kept in the
existing `test_muon.py`. Requires >=2 GPUs (fp16).
Refs #7807
cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana
Signed-off-by: whycoming <alwaysxd666@gmail.com>