pytorch
53c94ef1 - [generate_vmap_rule] Add mechanism to override ctx.saved_tensors (CtxWithSavedTensors) (#90964)

Commit
2 years ago
[generate_vmap_rule] Add mechanism to override ctx.saved_tensors (CtxWithSavedTensors) (#90964) As seen in https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit#heading=h.r3ckcnsh1cxt This PR creates CtxWithSavedTensors. You can wrap a ctx object in the backward pass of autograd.Function in CtxWithSavedTensors and specify the saved_tensors attribute. CtxWithSavedTensor acts like the original ctx object (all other attribute accesses are forwarded to the original ctx object) but it has a custom saved_tensors field. Test Plan: - tests that you can use CtxWithSavedTensors to get a new object with your own saved_tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90964 Approved by: https://github.com/samdow, https://github.com/soulitzer
Author
Committer
Parents
Loading