DeepSpeed
9d632f10 - Fix zero 1 and 2 CPU-offloaded gradient norm (#7967)

This PR fixes a bug introduced in [#6550](https://github.com/deepspeedai/DeepSpeed/pull/6550), which was also pointed out in [this comment](https://github.com/deepspeedai/DeepSpeed/pull/6550#issuecomment-2419308783). The issue is that gradients are only copied to CPU when `micro_step_id == 0`. For `micro_step_id > 0`, the gradients were effectively dropped instead of being accumulated, which leads to an artificially smaller gradient norm. With this fix, gradients are copied and accumulated on every microstep, matching the expected behavior and restoring the correct gradient norm.

The plot below shows the impact clearly: the previous implementation significantly underestimates the gradient norm compared to the fixed version.

<img width="808" height="584" alt="grad_norms" src="https://github.com/user-attachments/assets/6a0d968c-88cc-4b69-b990-3e2aa1c892b0" />

Setup: SFT run using OpenRLHF with DeepSpeed.

- OpenRLHF CPU-offloaded buggy baseline: gradients dropped for microstep > 0
- OpenRLHF CPU-offloaded fixed version: correct accumulation across all microsteps
- OpenRLHF GPU, non-offloaded version: reference correct behavior
- Verl (FSDP optimizer): additional reference baseline using PyTorch FSDP

The fixed version matches non-offloaded DeepSpeed and FSDP, confirming correct gradient accumulation.

Effect on loss:

<img width="2943" height="1742" alt="loss_cpu_optimizer_comparison" src="https://github.com/user-attachments/assets/edf1dfd7-9b5f-46fe-b174-fcc57b36225c" />

---------

Signed-off-by: Alexis Limozin <alexis@limozin.net>
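The effect described above can be illustrated with a minimal, self-contained sketch (not the actual DeepSpeed code; the buffer names and gradient values are hypothetical). Copying the gradient only at `micro_step_id == 0` discards the contributions of later microsteps, so the resulting norm is smaller than the true accumulated norm:

```python
import math

def l2_norm(vec):
    # L2 norm over a flattened gradient buffer
    return math.sqrt(sum(x * x for x in vec))

# Hypothetical per-microstep gradients for a 4-element parameter,
# with 4 gradient-accumulation microsteps (values are illustrative only).
micro_grads = [[1.0, 1.0, 1.0, 1.0] for _ in range(4)]

# Buggy behavior (pre-fix): the gradient is copied to the CPU buffer
# only when micro_step_id == 0; later microsteps are effectively dropped.
buggy_buf = [0.0] * 4
for micro_step_id, g in enumerate(micro_grads):
    if micro_step_id == 0:
        buggy_buf = list(g)  # copy, but never accumulate afterwards

# Fixed behavior: copy and accumulate on every microstep.
fixed_buf = [0.0] * 4
for g in micro_grads:
    fixed_buf = [a + b for a, b in zip(fixed_buf, g)]

print(l2_norm(buggy_buf))  # 2.0 -- artificially small norm
print(l2_norm(fixed_buf))  # 8.0 -- correct accumulated norm
```

With four identical microstep gradients, the buggy path reports a norm 4x smaller than the fixed path, mirroring the underestimation visible in the plot.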