DeepSpeed
67b9e212 - fix: add setup_context for torch.func compatibility (#7916)

Commit
19 days ago
fix: add setup_context for torch.func compatibility (#7916) `LinearFunctionForZeroStage3` uses the legacy `forward(ctx, ...)` pattern which is incompatible with `torch.func` transforms (`torch.func.grad`, `torch.func.grad_and_value`, `vmap`, etc.): ``` RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. ``` This affects any library that uses `torch.func` internally on a ZeRO-3 model. ## Fix Fixes #7913 ## Note As pointed out by @zhangj1an in #7913, `PostBackwardFunctionModule` and `PreBackwardFunctionForModule` in `parameter_offload.py` have the same issue. Those will be addressed in a follow-up commit within this PR. --------- Signed-off-by: Sung Hyun Cho <hope5487@gmail.com> Signed-off-by: Zhang <jianmusings@gmail.com> Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com> Signed-off-by: Zhang Jian <jianmusings@gmail.com> Co-authored-by: zhangj1an <jianmusings@gmail.com> Co-authored-by: Zhang Jian <zhang.jian@u.nus.edu> Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>
Author
Parents
Loading