DeepSpeed
e4f6da96 - [bugfix] fix partition context unpatch (#7566)

Commit
102 days ago
[bugfix] fix partition context unpatch (#7566)

## Fix asymmetric patching/unpatching in InsertPostInitMethodToModuleSubClasses

### Problem Description

The `InsertPostInitMethodToModuleSubClasses` context manager patches `__init__` methods of model classes during entry and unpatches them during exit. However, asymmetric condition checks between patching and unpatching can introduce subtle inheritance bugs.

### Root Cause Analysis

The issue occurs with classes that have multiple inheritance where:

1. **Child class A** does not override `__init__`
2. **Parent class B** does not inherit from `nn.Module`
3. **Parent class C** inherits from `nn.Module`

**Current asymmetric logic:**

```python
# Patching (entry): Only patch classes with explicit __init__
def _enable_class(cls):
    if '__init__' in cls.__dict__:  # ✅ Strict check
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)

# Unpatching (exit): Restore any class with _old_init
def _disable_class(cls):
    if hasattr(cls, '_old_init'):  # ❌ Permissive check
        cls.__init__ = cls._old_init
```

**Execution flow:**

1. **During entry**: Child A is skipped (no explicit `__init__`), Parent C is patched
2. **During exit**: Child A inherits `_old_init` from Parent C and gets incorrectly "restored"

**Result**: Child A's `__init__` points to Parent C's original `__init__`, bypassing Parent B and breaking the inheritance chain.

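The execution flow described above can be reproduced without PyTorch or DeepSpeed. The following is a minimal sketch, not the DeepSpeed implementation: `Module` is a stand-in for `nn.Module`, `partition_after` is reduced to a hypothetical pass-through wrapper, `all_subclasses` is an illustrative helper, and `enable_class` / `disable_class` mirror the two snippets in the root cause analysis.

```python
import functools


class Module:
    """Stand-in for torch.nn.Module (assumption: keeps the sketch torch-free)."""

    def __init__(self):
        self.calls = ["Module"]


class ParentB:
    """Parent B: does NOT inherit from Module."""

    def __init__(self):
        super().__init__()
        self.calls.append("ParentB")


class ParentC(Module):
    """Parent C: inherits from Module and defines its own __init__."""

    def __init__(self):
        super().__init__()
        self.calls.append("ParentC")


class ChildA(ParentB, ParentC):
    pass  # Child A: no explicit __init__


def partition_after(init):
    """Hypothetical pass-through wrapper; the real one does more after init."""

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        init(self, *args, **kwargs)

    return wrapper


def all_subclasses(cls):
    """Illustrative helper: collect direct and indirect subclasses of cls."""
    found = []
    for sub in cls.__subclasses__():
        found.append(sub)
        found.extend(all_subclasses(sub))
    return found


def enable_class(cls):
    # Strict check: only classes that define __init__ themselves are patched.
    if "__init__" in cls.__dict__:
        cls._old_init = cls.__init__
        cls.__init__ = partition_after(cls.__init__)


def disable_class(cls):
    # Permissive check: hasattr also sees an _old_init inherited from a parent.
    if hasattr(cls, "_old_init"):
        cls.__init__ = cls._old_init


before = ChildA().calls  # normal MRO: ParentB runs last, after ParentC

for cls in all_subclasses(Module):
    enable_class(cls)   # patches ParentC, skips ChildA
for cls in all_subclasses(Module):
    disable_class(cls)  # "restores" ChildA from ParentC's inherited _old_init

after = ChildA().calls
leaked = "__init__" in ChildA.__dict__

print(before)  # ['Module', 'ParentC', 'ParentB']
print(after)   # ['Module', 'ParentC']
print(leaked)  # True
```

After the unpatch pass, `ChildA` carries its own `__init__` attribute that it never had before, and that attribute points at `ParentC`'s original initializer, so `ParentB.__init__` no longer runs.
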
### Reproduction Case

This pattern is common in Hugging Face models:

```python
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    pass  # No explicit __init__

# GenericForSequenceClassification - not a nn.Module subclass
# Qwen3PreTrainedModel - inherits from nn.Module
```

### Solution

Apply symmetric condition checking in both patch and unpatch operations:

```python
def _disable_class(cls):
    # Match the patching condition: only restore classes we explicitly patched
    if '__init__' in cls.__dict__ and hasattr(cls, '_old_init'):
        cls.__init__ = cls._old_init
        delattr(cls, '_old_init')  # Optional cleanup
```

This ensures that only classes that were explicitly patched during entry get restored during exit.

### Testing

The fix has been validated against the Qwen3ForSequenceClassification reproduction case and resolves the inheritance chain corruption.

### Related Issues

- External issue: https://github.com/modelscope/ms-swift/pull/5820

Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>