zero3: invalidate coordinator trace on hook re-registration (#8043)
## Summary
Re-registering ZeRO-3 module hooks after they were removed (e.g. via
`unwrap_model_for_generation`) leaves the param coordinator's recorded
trace stale. The next training forward raises `IndexError: pop from an
empty deque` from `_start_of_forward_hook -> reset_step ->
record_parameters -> popleft`.
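The error text itself is the standard message Python produces when `popleft` is called on an empty `collections.deque`. A minimal standalone illustration, independent of DeepSpeed (the variable names are illustrative only):

```python
from collections import deque

# An empty queue, standing in for a per-module queue that has no
# entry left to consume.
step_ids = deque()

try:
    step_ids.popleft()
except IndexError as exc:
    # Same message as the one reported in the traceback above.
    message = str(exc)
```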
## Repro
DeepSpeed master, torch 2.8.0+cu128, transformers, peft. Single GPU.
```python
import torch, deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

m = "hf-internal-testing/tiny-random-gpt2"
tok = AutoTokenizer.from_pretrained(m); tok.pad_token = tok.eos_token
model = get_peft_model(AutoModelForCausalLM.from_pretrained(m, dtype=torch.bfloat16),
                       LoraConfig(task_type=TaskType.CAUSAL_LM, r=4, target_modules=["c_attn"]))
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
cfg = {"train_micro_batch_size_per_gpu": 1, "bf16": {"enabled": True},
       "zero_optimization": {"stage": 3, "stage3_param_persistence_threshold": 0},
       "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
engine, *_ = deepspeed.initialize(model=model, config=cfg,
                                  model_parameters=[p for p in model.parameters() if p.requires_grad])
ids = tok("hello", return_tensors="pt").input_ids.to(engine.device)
for _ in range(2):
    with unwrap_model_for_generation(engine) as unwrapped:
        with torch.no_grad():
            unwrapped.generate(ids, max_new_tokens=4, do_sample=False, pad_token_id=tok.pad_token_id)
    out = engine(input_ids=ids, labels=ids)
    engine.backward(out.loss); engine.step()
```
Run with ``torchrun --nproc-per-node=1 repro.py``. The second iteration
raises the ``IndexError``.
## Fix
Two small edits in ``deepspeed/runtime/zero/``:
- ``parameter_offload.py::_register_deepspeed_module``: when the root
module is re-registered, invalidate the coordinator trace so the next
forward re-records cleanly.
- ``partitioned_param_coordinator.py::_clear_trace_structures``: also
clear ``__step_id_module_fetched_for``, which was being left populated
and caused the empty-deque pop.
Both guards are no-ops on initial registration (trace is already
INVALID) and on non-root submodule walks.
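The shape of the two edits can be sketched with a toy coordinator. This is a minimal illustration in plain Python, not DeepSpeed code: ``ToyCoordinator``, ``TraceMode``, ``record_module``, ``invalidate_trace`` and ``register_hooks`` are hypothetical names chosen for the sketch, and the real methods carry more state.

```python
from collections import defaultdict, deque
from enum import Enum


class TraceMode(Enum):
    INVALID = 0
    RECORD = 1
    COMPLETE = 2


class ToyCoordinator:
    """Hypothetical stand-in for the param coordinator (not DeepSpeed code)."""

    def __init__(self):
        self.mode = TraceMode.INVALID
        self.submodule_order = []
        # Per-module queue of step ids, loosely modelled on
        # __step_id_module_fetched_for.
        self.step_id_module_fetched_for = defaultdict(deque)

    def record_module(self, module_id, step_id):
        self.mode = TraceMode.RECORD
        self.submodule_order.append(module_id)
        self.step_id_module_fetched_for[module_id].append(step_id)

    def clear_trace_structures(self):
        self.submodule_order = []
        # Second edit: also drop the per-module step ids.
        self.step_id_module_fetched_for.clear()

    def invalidate_trace(self):
        # Guard: nothing to do when the trace is already INVALID.
        if self.mode is TraceMode.INVALID:
            return False
        self.clear_trace_structures()
        self.mode = TraceMode.INVALID
        return True


def register_hooks(coordinator, is_root):
    """First edit: only registering the root module invalidates the trace."""
    if is_root:
        return coordinator.invalidate_trace()
    return False


coord = ToyCoordinator()
first = register_hooks(coord, is_root=True)       # initial registration: no-op
coord.record_module("layer0", step_id=0)
submodule = register_hooks(coord, is_root=False)  # submodule walk: no-op
again = register_hooks(coord, is_root=True)       # re-registration: clears state
```

In this sketch only the last call changes anything: it empties both the module order and the per-module step ids, then returns the coordinator to ``INVALID`` so the next forward would start recording from scratch.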
## Test
``tests/unit/runtime/zero/test_unwrap_model.py::TestUnwrapModelTraceInvalidate``
covers the path: run one training step, wrap with
``unwrap_model_for_generation``, assert the coordinator returns to
INVALID. World size 2.
---------
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>