064ae9ff - Support register_hook on input tensors (#108903)

Support register_hook on input tensors (#108903)

The strategy in this PR is pretty straightforward. There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs)

Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced).

The plan (usage sketches for each case follow below):

**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo produces the bytecode that stitches our PT2-modified bytecode back together with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends and (b) we know the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame, so we can soundly know it will still be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.

**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR.

**Handles:**
An interesting new component here is the creation of a `STORE_FAST` -> `LOAD_FAST` pair associated with the handle, the return value of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding on the `STORE_FAST`, recording the name of the local variable as directed by the user code, and then honoring that same name in the reconstructed bytecode. If the user did not store the handle, we merely pop the produced value to preserve the stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
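To make the sourced-tensor path concrete, here is a minimal usage sketch of the behavior the PR describes, assuming a PyTorch build that includes this change; the hook body and the `backend="eager"` choice are illustrative only. The hook is registered on a compiled-region input (a tensor with a source), so the registration can be recorded and replayed via `register_hook` in the residual bytecode:

```python
import torch

def double_grad(grad):
    # Illustrative hook: scale the gradient flowing into the input tensor.
    return grad * 2

@torch.compile(backend="eager")
def fn(x, y):
    # `x` is a graph input (it has a source), so per this PR the hook
    # registration is recorded and replayed after the compiled region.
    x.register_hook(double_grad)
    return x * y

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
fn(x, y).sum().backward()
print(x.grad)  # equals 2 * y because the hook doubled the gradient
```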
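For contrast, a hedged sketch of the unsourced case described above: a hook registered on an intermediate tensor (no source) is expected, at this point in the stack, to fall back to a graph break rather than be compiled. The function and hook names are chosen for illustration.

```python
import torch

def clamp_grad(grad):
    # Illustrative hook on an intermediate value.
    return grad.clamp(-1.0, 1.0)

@torch.compile(backend="eager")
def fn(x):
    tmp = x * 2                     # `tmp` is an intermediary: it has no source
    tmp.register_hook(clamp_grad)   # expected to trigger a graph break in this PR
    return tmp + 1

x = torch.randn(3, requires_grad=True)
fn(x).sum().backward()  # still correct via the eager fallback after the break
```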
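The handle behavior can be illustrated with a similar sketch (again assuming a build with this change; names are hypothetical): whether or not the user keeps the `RemovableHandle`, the reconstructed bytecode must either honor the local-variable name chosen at the `STORE_FAST` or pop the value to keep the stack balanced.

```python
import torch

def halve_grad(grad):
    return grad * 0.5

@torch.compile(backend="eager")
def fn(x):
    # Handle kept: the local name "h" recorded at the STORE_FAST is reused
    # when the frame's bytecode is reconstructed.
    h = x.register_hook(halve_grad)
    # Handle discarded: the reconstructed bytecode simply pops the returned
    # value so the evaluation stack stays balanced.
    x.register_hook(halve_grad)
    return x + 1

x = torch.randn(2, requires_grad=True)
fn(x).sum().backward()
print(x.grad)  # both hooks fire, scaling the gradient of ones down to 0.25
```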