DeepSpeed
87f31d55 - zero3: invalidate coordinator trace on hook re-registration (#8043)

Commit
12 days ago
zero3: invalidate coordinator trace on hook re-registration (#8043) ## Summary Re-registering ZeRO-3 module hooks after they were removed (e.g. via `unwrap_model_for_generation`) leaves the param coordinator's recorded trace stale. The next training forward raises `IndexError: pop from an empty deque` from `_start_of_forward_hook -> reset_step -> record_parameters -> popleft`. ## Repro DeepSpeed master, torch 2.8.0+cu128, transformers, peft. Single GPU. ```python import torch, deepspeed from deepspeed.runtime.zero import unwrap_model_for_generation from transformers import AutoModelForCausalLM, AutoTokenizer from peft import LoraConfig, get_peft_model, TaskType m = "hf-internal-testing/tiny-random-gpt2" tok = AutoTokenizer.from_pretrained(m); tok.pad_token = tok.eos_token model = get_peft_model(AutoModelForCausalLM.from_pretrained(m, dtype=torch.bfloat16), LoraConfig(task_type=TaskType.CAUSAL_LM, r=4, target_modules=["c_attn"])) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) cfg = {"train_micro_batch_size_per_gpu": 1, "bf16": {"enabled": True}, "zero_optimization": {"stage": 3, "stage3_param_persistence_threshold": 0}, "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}} engine, *_ = deepspeed.initialize(model=model, config=cfg, model_parameters=[p for p in model.parameters() if p.requires_grad]) ids = tok("hello", return_tensors="pt").input_ids.to(engine.device) for _ in range(2): with unwrap_model_for_generation(engine) as unwrapped: with torch.no_grad(): unwrapped.generate(ids, max_new_tokens=4, do_sample=False, pad_token_id=tok.pad_token_id) out = engine(input_ids=ids, labels=ids) engine.backward(out.loss); engine.step() ``` Run with ``torchrun --nproc-per-node=1 repro.py``. Second iteration raises the IndexError. ## Fix Two small edits in ``deepspeed/runtime/zero/``: - ``parameter_offload.py::_register_deepspeed_module``: when the root module is re-registered, invalidate the coordinator trace so the next forward re-records cleanly. - ``partitioned_param_coordinator.py::_clear_trace_structures``: also clear ``__step_id_module_fetched_for``, which was being left populated and caused the empty-deque pop. Both guards are no-ops on initial registration (trace is already INVALID) and on non-root submodule walks. ## Test ``tests/unit/runtime/zero/test_unwrap_model.py::TestUnwrapModelTraceInvalidate`` covers the path: run one training step, wrap with ``unwrap_model_for_generation``, assert the coordinator returns to INVALID. World size 2. --------- Signed-off-by: Sung Hyun Cho <hope5487@gmail.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Author
Parents
Loading