fix: add setup_context for torch.func compatibility (#7916)
`LinearFunctionForZeroStage3` uses the legacy `forward(ctx, ...)`
pattern, which is incompatible with `torch.func` transforms
(`torch.func.grad`, `torch.func.grad_and_value`, `torch.func.vmap`, etc.).
Applying one of these transforms raises:
```
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
```
This affects any library that uses `torch.func` internally on a ZeRO-3
model.
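## Fix
Add a `setup_context` staticmethod to `LinearFunctionForZeroStage3`, as
the error message requires.

Fixes #7913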
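
For reference, the pattern the error message asks for can be sketched as
below. This is a minimal illustration, not the DeepSpeed code: `SketchLinear`
and `loss_fn` are hypothetical names, and the body is a plain linear layer
without any ZeRO-3 parameter handling. It assumes PyTorch 2.0 or later, where
`torch.func` is available.

```python
import torch


class SketchLinear(torch.autograd.Function):
    """Hypothetical stand-in; not DeepSpeed's LinearFunctionForZeroStage3."""

    @staticmethod
    def forward(input, weight, bias):
        # New-style forward: no ctx argument, it only computes the output.
        return input.matmul(weight.t()) + bias

    @staticmethod
    def setup_context(ctx, inputs, output):
        # State needed by backward is saved here instead of in forward.
        input, weight, _bias = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(input)
        grad_bias = grad_output.sum(0)
        # One gradient per argument of forward, in the same order.
        return grad_input, grad_weight, grad_bias


def loss_fn(weight, x, bias):
    # Hypothetical scalar loss so that torch.func.grad can differentiate it.
    return SketchLinear.apply(x, weight, bias).sum()


x = torch.randn(4, 3)
w = torch.randn(2, 3)
b = torch.randn(2)

# Differentiates with respect to the first argument (weight).
grad_w = torch.func.grad(loss_fn)(w, x, b)
```

With the legacy `forward(ctx, ...)` signature and no `setup_context`, the
same `torch.func.grad` call is what triggers the `RuntimeError` quoted above.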
## Note
As pointed out by @zhangj1an in #7913, `PostBackwardFunctionModule` and
`PreBackwardFunctionForModule` in `parameter_offload.py` have the same
issue. Those will be addressed in a follow-up commit within this PR.
---------
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Signed-off-by: Zhang <jianmusings@gmail.com>
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Zhang Jian <jianmusings@gmail.com>
Co-authored-by: zhangj1an <jianmusings@gmail.com>
Co-authored-by: Zhang Jian <zhang.jian@u.nus.edu>
Co-authored-by: Masahiro Tanaka <mtanaka@anyscale.com>