Extending Muon Optimizer Support for ZeRO Stage 3 (#7919)
Authors: @pengdurice and @PKUWZP
Create a separate PR based on #7798 with the same functional diff on a
clean signed-off branch to resolve DCO issues.
We aim on adding Muon Optimizer to zero stage 3 in this draft PR:
- Created a dedicated momentum buffer in zero stage 3 optimizer to save
the momentum buffers specifically for Muon Optimizer.
- The optimizer states can be dispatched into 3 devices: GPU, CPU and
NVME. For GPU and CPU, we just make the new buffers the same device of
`self.fp32_partitioned_groups_flat`; when `device == NVME`, we make sure
that the momentum buffers can be swapped in and out along with other
components in the optimizer states.
- The new momentum buffers are also partitioned like
`self.fp32_partitioned_groups_flat` to save memory footprint. So, before
the muon update, we need to perform `all_gather` on top of each
data-parallel group rank. The Muon updates of the parameters are also
divided across the data-parallel ranks, and the results are all-gathered
once all updates are complete. After the `all_gather`, the momentum
buffers are partitioned and flattened again.
Next steps:
- Explore quantization of momentum buffers for saving memory
- Explore using highly optimized Adam / AdamW Optimizers
---------
Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>