pytorch
2eea3cb1 - Fix composable `checkpoint(use_reentrant=True)` with multi args (#103590)

Commit
2 years ago
Fix composable `checkpoint(use_reentrant=True)` with multi args (#103590) The `_ModuleHookCheckpointFunction.backward()` should take in `*output_grads` instead of `output_grads`. Otherwise, we may see an error like: ``` TypeError: backward() takes 2 positional arguments but 5 were given ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/103590 Approved by: https://github.com/rohan-varma, https://github.com/fduwjj, https://github.com/fegin
Author
Andrew Gu
Committer
Parents
Loading