DeepSpeed
1e85ce2b - [DeepCompile] fix gather params in dynamo skipped frames for ZeRO3 (#8059)

Commit
3 days ago
Fixes #7942

**Root cause:** When `init_z3()` initializes DeepCompile, it removes all three parameter-gathering mechanisms (ZeROOrderedDict, module hooks, engine forward hooks) and relies entirely on compiled FX graph ops for allgather/release. However, `torch._dynamo` may skip entire frames when it detects graph breaks in for/while loops. Skipped frames execute eagerly with no gathering mechanism, so parameters stay partitioned at shape `[0]`.

## Testing

Validated on 2× H200 with ZeRO3 + DeepCompile:

| Test | Result |
|------|--------|
| **Qwen2 MoE** (actual failing model from #7942) | PASS (5 training steps, no crash) |
| **LLaMA** (regression, already worked) | PASS |
| **Tied embeddings** (shared param across frame types) | PASS |
| **Gradient correctness** (loss decreases on fixed input) | PASS (12.09 → 10.28) |
| **Guard stability** (no recompilation loops) | PASS (22 compilations with the fix, 22 without) |
| **Existing test_compile.py** (100 steps) | PASS |

## Test plan

- [x] `pre-commit run` passes on all changed files
- [x] Existing `tests/torch_compile/test_compile.py` passes (2 GPU, ZeRO-3)
- [x] New regression test `test_deepcompile_skipped_frame.py` passes
- [x] Real Qwen2 MoE model trains without crash
- [x] Real LLaMA model trains without regression
- [x] Tied embedding model trains without crash
- [x] Compilation count unchanged vs upstream (no guard instability)

cc @tohtana @eternalNight

Signed-off-by: ahpoddar <ahpoddar@redhat.com>
Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>
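
The failure mode described under the root cause can be illustrated with a toy model. This is only a sketch in plain Python: it does not use DeepSpeed or PyTorch, and `PartitionedParam`, `run_compiled_frame`, and `run_skipped_frame` are hypothetical names invented for illustration. It shows why a frame that runs eagerly, with no gather step of its own, sees a parameter of shape `[0]`; it does not reproduce the actual fix in this commit.

```python
# Toy model of the failure mode; no DeepSpeed or PyTorch APIs are used.
from dataclasses import dataclass, field
from typing import List


@dataclass
class PartitionedParam:
    """Hypothetical stand-in for a ZeRO-3 parameter sharded across ranks."""
    shards: List[List[float]]                        # one shard per rank
    data: List[float] = field(default_factory=list)  # empty while partitioned

    @property
    def shape(self) -> List[int]:
        return [len(self.data)]

    def gather(self) -> None:
        # Stand-in for allgather: concatenate every rank's shard.
        self.data = [x for shard in self.shards for x in shard]

    def release(self) -> None:
        # Stand-in for release: drop the full parameter again.
        self.data = []


def weighted_sum(param: PartitionedParam, inputs: List[float]) -> float:
    if len(param.data) != len(inputs):
        raise RuntimeError(
            f"shape mismatch: param {param.shape} vs input [{len(inputs)}]"
        )
    return sum(w * x for w, x in zip(param.data, inputs))


def run_compiled_frame(param: PartitionedParam, inputs: List[float]) -> float:
    # Compiled path: gather and release are part of the graph itself.
    param.gather()
    try:
        return weighted_sum(param, inputs)
    finally:
        param.release()


def run_skipped_frame(param: PartitionedParam, inputs: List[float]) -> float:
    # Skipped frame: runs eagerly, and nothing gathers the parameter first.
    return weighted_sum(param, inputs)


param = PartitionedParam(shards=[[1.0, 2.0], [3.0, 4.0]])
inputs = [1.0, 1.0, 1.0, 1.0]

compiled_result = run_compiled_frame(param, inputs)

try:
    run_skipped_frame(param, inputs)
    skipped_error = None
except RuntimeError as exc:
    skipped_error = str(exc)
```

In this sketch the compiled path succeeds because it gathers before use, while the skipped path fails on the still-partitioned parameter. Any remedy has to give eagerly executed frames some way to gather parameters before they are read.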