DeepSpeed
6c59d544 - Fix hook count performance regression from v0.18.5 (#7886)

Commit
3 days ago
Fix hook count performance regression from v0.18.5 (#7886) Fixes performance regressions reported in #7882 and #7885. PR #7780 added dynamic hook count computation for reentrant checkpointing correctness, but placed the call inside every gradient hook closure. For a model with n parameter tensors, this creates significant overhead per backward pass. Summary: 1. Added `should_refresh_expected_hook_count()` predicate that returns true only at backward phase boundaries (first hook, or new reentrant phase), so `count_used_parameters_in_backward()` is called once per phase instead of once per hook. 2. Applied this predicate in ZeRO-1/2 (stage_1_and_2.py) and both ZeRO-3 hook sites (stage3.py), reusing the `cached_max_expected_hooks_seen` value when refresh isn't needed. 3. Changed enter_backward() to reset hook counters on first real backward entry, preventing pollution from pre-user-backward autograd calls (e.g., TiledFusedLogitsLoss). With 24-layer transformer, ~267M params (147 parameter tensors), ZeRO-2, 8×H100 80GB, bf16, batch size 8, 20 warmup + 20 measured iterations: - Before fix: 0.1265s/iter - After fix: 0.0505s/iter --------- Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com> Co-authored-by: Ramya Ramineni <rraminen@users.noreply.github.com>
Author
Parents
Loading