Fix hook count performance regression from v0.18.5 (#7886)
Fixes performance regressions reported in #7882 and #7885.
PR #7780 added dynamic hook count computation for reentrant
checkpointing correctness, but placed the call inside every gradient
hook closure. For a model with n parameter tensors, this creates
significant overhead per backward pass.
Summary:
1. Added `should_refresh_expected_hook_count()` predicate that returns
true only at backward phase boundaries (first hook, or new reentrant
phase), so `count_used_parameters_in_backward()` is called once per
phase instead of once per hook.
2. Applied this predicate in ZeRO-1/2 (stage_1_and_2.py) and both ZeRO-3
hook sites (stage3.py), reusing the `cached_max_expected_hooks_seen`
value when refresh isn't needed.
3. Changed enter_backward() to reset hook counters on first real
backward entry, preventing pollution from pre-user-backward autograd
calls (e.g., TiledFusedLogitsLoss).
With 24-layer transformer, ~267M params (147 parameter tensors), ZeRO-2,
8×H100 80GB, bf16, batch size 8, 20 warmup + 20 measured iterations:
- Before fix: 0.1265s/iter
- After fix: 0.0505s/iter
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Ramya Ramineni <rraminen@users.noreply.github.com>