DeepSpeed
5b314f4e - Avoid overwrite of compiled module wrapper attributes (#5549)

Commit
1 year ago
Avoid overwrite of compiled module wrapper attributes (#5549)

**Fix overwriting of the compiled wrapper class attributes by those of the wrapped class itself: copy only those attributes which are not already present in the wrapper.**

In the current implementation of `CompiledModuleWrapper`, the wrapper attributes (e.g. the `forward` method) are overwritten by `self.__dict__ = module.__dict__.copy()`:

```
def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

    class wrapper(mod.__class__):

        def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
            self.__dict__ = module.__dict__.copy()
```

As a result, the `wrapper`'s `forward` method is not called and, consequently, the wrapped module is not compiled. Instead, the wrapped module's `forward` method is called, as illustrated in the diagram below (a real scenario from DeepSpeed-Chat):

![compiled_module_wrapper_bug](https://github.com/microsoft/DeepSpeed/assets/75629718/00eeb3d1-927c-49c7-84ab-f882821cc452)

The proposed fix copies only those attributes which are not present in the wrapper class, thus implementing the desired inheritance behavior of the wrapper.

Attached is a simple reproducer of the problem.
[compiled_module_wrapper_bug.zip](https://github.com/microsoft/DeepSpeed/files/15378282/compiled_module_wrapper_bug.zip)

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
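The shadowing described in this commit message can be reproduced without PyTorch or DeepSpeed. The following is a minimal sketch, not the actual DeepSpeed code: `Module`, `make_wrapper`, the `copy_all` flag, and the `hidden_size` attribute are hypothetical names introduced for illustration. It assumes that the wrapped object carries a `forward` entry in its instance `__dict__` (here set by an instance-level patch), which is what lets a wholesale `__dict__` copy hide the wrapper's own `forward`.

```python
class Module:
    """Hypothetical stand-in for a wrapped module (not torch.nn.Module)."""

    def forward(self, x):
        return f"eager({x})"


def make_wrapper(mod, copy_all):
    """Build a wrapper instance; `copy_all` selects the buggy or the filtered copy."""

    class Wrapper(mod.__class__):

        def __init__(self, module):
            if copy_all:
                # Buggy variant: every instance attribute is copied, including
                # an instance-level `forward`, which then shadows Wrapper.forward.
                self.__dict__ = module.__dict__.copy()
            else:
                # Filtered variant: skip names the wrapper class defines itself.
                own = type(self).__dict__
                self.__dict__ = {k: v for k, v in module.__dict__.items() if k not in own}

        def forward(self, x):
            # Stands in for the path that would trigger compilation.
            return f"compiled({x})"

    return Wrapper(mod)


mod = Module()
mod.hidden_size = 8                          # ordinary instance state
mod.forward = lambda x: f"patched({x})"      # instance-level `forward` (assumption)

buggy = make_wrapper(mod, copy_all=True)
fixed = make_wrapper(mod, copy_all=False)

print(buggy.forward(1))   # patched(1): the wrapper's forward is bypassed
print(fixed.forward(1))   # compiled(1): the wrapper's forward is used
print(fixed.hidden_size)  # 8: other attributes are still inherited
```

Instance attributes take precedence over plain methods defined on the class, so filtering the copied dictionary by the names in the wrapper class is enough to keep the wrapper's `forward` reachable while preserving the rest of the wrapped object's state.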