[generate_vmap_rule] Add mechanism to override ctx.saved_tensors (CtxWithSavedTensors) (#90964)
As seen in
https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit#heading=h.r3ckcnsh1cxt
This PR creates CtxWithSavedTensors. You can wrap a ctx object in the
backward pass of autograd.Function in CtxWithSavedTensors and specify
the saved_tensors attribute. CtxWithSavedTensor acts like the original
ctx object (all other attribute accesses are forwarded to the original ctx
object) but it has a custom saved_tensors field.
Test Plan:
- tests that you can use CtxWithSavedTensors to get a new object with
your own saved_tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90964
Approved by: https://github.com/samdow, https://github.com/soulitzer