Properly propagates checkpoint wrapper args and kwargs (#99791)
It looks like passing `*args` and `**kwargs` to `checkpoint_wrapper()` does not work because someone forgot some `*`s. This adds them back in.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99791
Approved by: https://github.com/awgu