DeepSpeed
311674ff - Fix gradient checkpointing with use_reentrant=True / PyTorch-style backward / ZeRO-3 (#7780)

Fix gradient checkpointing with use_reentrant=True / PyTorch-style backward / ZeRO-3 (#7780)

Fixes an issue where ZeRO-3 with gradient checkpointing (`use_reentrant=True`) and non-scalar backward (`tensor.backward(gradient=...)`) would fail on subsequent training iterations with `AttributeError: 'NoneType' object has no attribute 'numel'`.

The root cause was stale parameters remaining in `ipg_buckets` between iterations, left behind when the checkpoint's backward re-runs the forward pass. The fix clears these buckets in `_pre_step()` before each optimizer step.

Added comprehensive tests for gradient checkpointing with ZeRO-3 covering both `use_reentrant=True` and `use_reentrant=False` modes.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
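To illustrate the failure mode and the shape of the fix, here is a minimal stdlib-only sketch. It is not DeepSpeed's actual implementation: `Param`, `TinyReducer`, `add_grad`, and `step` are hypothetical stand-ins; only the `ipg_buckets` / `_pre_step()` names come from the commit. The idea is that a stale (already-freed) entry left in a bucket by the checkpoint's re-run forward would be dereferenced on the next iteration, and clearing the buckets in `_pre_step()` prevents that.

```python
class Param:
    """Hypothetical stand-in for a parameter with a pending gradient."""
    def __init__(self, n):
        self.data = list(range(n))

    def numel(self):
        return len(self.data)


class TinyReducer:
    """Hypothetical sketch of a ZeRO-style gradient bucket holder."""
    def __init__(self):
        # bucket id -> params whose gradients await reduction
        self.ipg_buckets = {0: []}

    def add_grad(self, p):
        self.ipg_buckets[0].append(p)

    def _pre_step(self):
        # The fix sketched: drop stale entries left over from the previous
        # iteration (e.g. by checkpoint's backward re-running forward)
        # before the optimizer step touches the buckets.
        for bucket in self.ipg_buckets.values():
            bucket.clear()

    def step(self):
        self._pre_step()
        # Without the _pre_step() clearing above, a stale None entry here
        # would raise AttributeError: 'NoneType' object has no attribute 'numel'
        return sum(p.numel()
                   for bucket in self.ipg_buckets.values()
                   for p in bucket)


r = TinyReducer()
r.add_grad(Param(4))
r.ipg_buckets[0].append(None)  # simulate a stale, already-freed parameter
print(r.step())  # buckets cleared first, so this completes: prints 0
```

With `_pre_step()` removed, the final call would crash exactly as the commit describes, which is why the defensive clear happens before each optimizer step rather than relying on every bucketed entry still being live.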