pytorch
571f96bf - cudagraph trees (#89146)

Commit
1 year ago
cudagraph trees (#89146) CUDA Graph Trees Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/edit Not currently implemented : - Right now, we are using weak tensor refs from outputs to check if a tensor has dies. This doesn't work because a) aliasing, and b) aot_autograd detaches tensors (see note [Detaching saved tensors in AOTAutograd]). Would need either https://github.com/pytorch/pytorch/issues/91395 to land to use storage weak refs or manually add a deleter fn that does what I want. This is doable but theres some interactions with the caching allocator checkpointing so saving for a stacked pr. - Reclaiming memory from the inputs during model recording. This isn't terribly difficult but deferring to another PR. You would need to write over the input memory during warmup, and therefore copy the inputs to cpu. Saving for a stacked pr. - Warning on overwriting previous generation outputs. and handling nested torch.compile() calls in generation tracking Differential Revision: [D43999887](https://our.internmc.facebook.com/intern/diff/D43999887) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89146 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading