DeepSpeed
956ec6fb - Extending Muon Optimizer Support for ZeRO Stage 3 (#7919)

Extending Muon Optimizer Support for ZeRO Stage 3 (#7919)

Authors: @pengdurice and @PKUWZP

Create a separate PR based on #7798 with the same functional diff on a clean signed-off branch to resolve DCO issues.

This draft PR adds the Muon Optimizer to ZeRO Stage 3:

- Created a dedicated momentum buffer in the ZeRO Stage 3 optimizer to hold the momentum state specific to the Muon Optimizer.
- The optimizer states can be dispatched to three devices: GPU, CPU, and NVMe. For GPU and CPU, the new buffers are simply placed on the same device as `self.fp32_partitioned_groups_flat`; when `device == NVME`, the momentum buffers are swapped in and out along with the other components of the optimizer state.
- The new momentum buffers are also partitioned like `self.fp32_partitioned_groups_flat` to reduce the memory footprint. Before the Muon update, each data-parallel rank therefore performs an `all_gather`. The Muon parameter updates are likewise divided across the data-parallel ranks, and the results are all-gathered once all updates are complete. After the `all_gather`, the momentum buffers are partitioned and flattened again.

Next steps:

- Explore quantization of the momentum buffers to save memory
- Explore using highly optimized Adam / AdamW optimizers

---------

Signed-off-by: PKUWZP <zhipeng.rainbowserie@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
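The gather-update-repartition flow described above can be sketched outside DeepSpeed. The sketch below is a minimal single-process simulation, not the PR's implementation: the function names (`newton_schulz_orthogonalize`, `muon_step`) and hyperparameters are hypothetical, `np.split`/`np.concatenate` stand in for ZeRO-3's partition and `all_gather` collectives, and the Newton-Schulz coefficients are those from the public Muon reference implementation.

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    """Approximately orthogonalize g with the quintic Newton-Schulz
    iteration used by Muon (coefficients from the public Muon
    reference implementation)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + eps)  # scale so singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                      # iterate on the wide orientation
        x = x.T
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * (A @ A)
        x = a * x + B @ x
    return x.T if transposed else x

def muon_step(param, grad, momentum, lr=0.02, beta=0.95):
    """One hypothetical Muon update on a fully gathered 2-D parameter:
    momentum accumulation followed by orthogonalization of the update."""
    momentum = beta * momentum + grad   # dedicated momentum buffer
    update = newton_schulz_orthogonalize(momentum)
    return param - lr * update, momentum

# Single-process stand-in for the ZeRO-3 flow with world_size = 2:
# each "rank" holds a flat shard of the parameter and its momentum.
rng = np.random.default_rng(0)
shape, world_size = (8, 8), 2
param_shards = np.split(rng.standard_normal(shape).ravel(), world_size)
mom_shards = np.split(np.zeros(shape).ravel(), world_size)
grad = rng.standard_normal(shape)

# 1) all_gather the flat shards and restore the 2-D parameter view.
param = np.concatenate(param_shards).reshape(shape)
momentum = np.concatenate(mom_shards).reshape(shape)

# 2) Muon update on the gathered tensors (in the real PR this work is
#    split across the data-parallel ranks and the results all-gathered).
param, momentum = muon_step(param, grad, momentum)

# 3) flatten and re-partition, one shard back per rank.
param_shards = np.split(param.ravel(), world_size)
mom_shards = np.split(momentum.ravel(), world_size)
```

The round trip in steps 1 and 3 is lossless, which is why the momentum buffers can live partitioned like `self.fp32_partitioned_groups_flat` between steps and only exist in gathered form during the update itself.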