Fix gradient checkpointing with use_reentrant=True / PyTorch-style backward / ZeRO-3 (#7780)
Fixes an issue where ZeRO-3 with gradient checkpointing
(`use_reentrant=True`) and non-scalar backward
(`tensor.backward(gradient=...)`) would fail on subsequent training
iterations with `AttributeError: 'NoneType' object has no attribute
'numel'`.
The root cause: with re-entrant checkpointing, the backward pass re-runs the
forward pass, which left stale parameters in `ipg_buckets` between
iterations. The fix clears these buckets in `_pre_step()` before each
optimizer step.
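For reference, the training pattern that triggered the failure combines both features. A minimal sketch in plain PyTorch (no ZeRO-3, so this loop runs fine on its own; under ZeRO-3 the same pattern raised the `AttributeError` on the second iteration before the fix):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Hypothetical minimal model standing in for a ZeRO-3-wrapped engine.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

for step in range(2):  # the bug surfaced on the second iteration
    x = torch.randn(2, 8, requires_grad=True)
    # Re-entrant checkpointing re-runs the forward pass during backward.
    out = checkpoint(model, x, use_reentrant=True)
    # Non-scalar backward: pass an explicit gradient instead of
    # reducing to a scalar loss first.
    out.backward(gradient=torch.ones_like(out))
    model.zero_grad()
```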
Added comprehensive tests for gradient checkpointing with ZeRO-3
covering both `use_reentrant=True` and `use_reentrant=False` modes.
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>