Reject Muon optimizer with reduce_scatter in ZeRO-1/2 (#8090)
## Summary
ZeRO-1/2 silently produces incorrect, rank-divergent parameter updates
when the Muon optimizer is used together with `reduce_scatter` (the
default). This adds an explicit error at initialization, mirroring the
existing ZeRO-3 guard, and includes a regression test. Closes #7807.
## Root cause
Muon's Newton-Schulz orthogonalization is a whole-matrix operation: the
rank that updates a parameter must hold that parameter's complete,
fully-reduced gradient matrix, then take its partition slice of the
orthogonalized result.
- `get_flat_partition()` (`deepspeed/runtime/zero/stage_1_and_2.py`)
applies `muon_update()` to each parameter's gradient reshaped to its
full 2-D shape, and only then narrows to this rank's partition.
- With `reduce_scatter=True`, `average_tensor()` reduce-scatters the
gradients: each rank receives the averaged values only for its own
partition slice. For the rest of a parameter whose flattened gradient
crosses a partition boundary, the rank still holds its local,
un-all-reduced gradient.
- So for any cross-partition parameter, no rank holds the full reduced
matrix. `muon_update` orthogonalizes a partly-reduced, rank-divergent
matrix, and each rank silently applies a different, incorrect update.
Parameters that lie wholly inside one partition are unaffected — exactly
matching the report.
ZeRO-3 already guards this exact conflict in
`deepspeed/runtime/zero/stage3.py` (added in #7919):
```python
if self.use_muon and self.reduce_scatter:
raise ValueError("Muon and reduce scatter cannot be used together")
```
ZeRO-1/2 had no equivalent. The existing Muon unit tests pin
`"reduce_scatter": false` everywhere, which implicitly acknowledges the
path is unsupported but never enforces it for users — and since
`reduce_scatter` defaults to `true`, a default Muon + ZeRO-1/2 run is
silently wrong.
## Fix
Mirror the ZeRO-3 guard in ZeRO-1/2: raise the same `ValueError` at
initialization when the optimizer is `MuonWithAuxAdam` and
`reduce_scatter` is enabled. To run Muon under ZeRO-1/2, set
`"reduce_scatter": false` (as the Muon tests already do). The change is
the import plus the guard, with no other behavioral change.
## Verification (2x RTX 4090, torch 2.9.1+cu128, ZeRO stage 1 and 2)
- **Before**: `deepspeed.initialize` with Muon + `reduce_scatter=true`
succeeds silently. With `world_size=2` and a model sized so a 2-D weight
straddles the gradient-partition boundary, that weight's post-step
update diverges by ~0.67 in relative Frobenius norm from the correct
full-gradient result, while wholly-owned weights are unaffected —
confirming the silent cross-partition corruption.
- **After**: the same configuration raises `ValueError: Muon and reduce
scatter cannot be used together` for both ZeRO stage 1 and 2. The
existing Muon tests (which use `reduce_scatter: false`) remain green.
## Notes
This supersedes #7878 and #7808, which aimed at the same issue by trying
to force a full all-reduce for Muon but ended up with a
self-contradictory guard. Aligning ZeRO-1/2 with the merged ZeRO-3
behavior (#7919) keeps the two code paths consistent and turns silent
numerical corruption into a clear, actionable error.
A follow-up PR adds a numerical-correctness regression test for the
supported `reduce_scatter: false` Muon path, since the current Muon
tests only assert that parameters changed.
Closes #7807
cc @PKUWZP @pengdurice (ZeRO-3 Muon guard, #7919) @tohtana
Signed-off-by: whycoming <alwaysxd666@gmail.com>
Co-authored-by: Ma, Guokai <guokai.ma@gmail.com>