4d494986 - [functorch] Refactor life handle storage (#90317)

A "life handle" is a pointer-to-boolean that says whether or not a TensorWrapper is alive. A TensorWrapper is alive if we are currently inside of its corresponding transform; likewise, an Interpreter is alive if we are currently inside of its corresponding transform. For example, for vmap(f)(x), the BatchedTensor(x, level=1) is alive inside of the execution of f, and the corresponding VmapInterpreter is alive inside of f.

Previously, there was a global map from level to life handle. It is possible to get into a state where multiple levels refer to different Interpreters (if the implementation of an operator calls back into functorch), and that corrupts the global map.

This PR changes it so that:
- every Interpreter holds a life handle that says whether it is alive
- to construct a TensorWrapper, one must either (a) directly pass it a life handle, or (b) create the TensorWrapper while the corresponding Interpreter is on the stack, in which case we automatically grab the life handle by indexing into the DynamicLayerStack with the level (see the sketch after this message)

(a) is more robust, so I changed most of our C++ callsites to do that. (b) feels a bit hacky to me, but it seems fine for now:
- it raises a clear error message if the interpreter isn't on the stack
- all of our Python callsites already follow this convention (we construct TensorWrappers after pushing the Interpreter onto the stack)

The alternative to (b) is to always do (a), which we can adopt in the future if (b) runs into any problems.

Test Plan:
- all functorch tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90317
Approved by: https://github.com/samdow
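The sketch below illustrates the shape of the design described above. It is not the actual PyTorch/functorch source: the names Interpreter, TensorWrapper, and DynamicLayerStack follow the commit message, but the signatures, the 1-indexed levels, and the use of std::shared_ptr<bool> as the life handle are assumptions made for this example.

```cpp
// Illustrative sketch only -- not the real functorch implementation.
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// A "life handle": a shared pointer-to-boolean recording whether the
// owning Interpreter's transform is still executing.
using LifeHandle = std::shared_ptr<bool>;

class Interpreter {
 public:
  explicit Interpreter(int64_t level)
      : level_(level), is_alive_(std::make_shared<bool>(true)) {}

  int64_t level() const { return level_; }

  // After this PR: every Interpreter owns its life handle directly,
  // instead of the handle living in a global level->handle map.
  const LifeHandle& is_alive_ptr() const { return is_alive_; }

  // Called when the transform exits; every TensorWrapper holding this
  // handle now observes that it is dead.
  void mark_dead() { *is_alive_ = false; }

 private:
  int64_t level_;
  LifeHandle is_alive_;
};

// The stack of currently-active transforms.
static std::vector<Interpreter> dynamicLayerStack;

class TensorWrapper {
 public:
  // Path (a): the caller passes the life handle explicitly. More
  // robust; the commit moved most C++ callsites to this form.
  TensorWrapper(int64_t level, LifeHandle handle)
      : level_(level), is_alive_(std::move(handle)) {}

  // Path (b): look the handle up by indexing the DynamicLayerStack
  // with the level. Only valid while the matching Interpreter is on
  // the stack; otherwise this raises an error.
  static TensorWrapper makeFromStack(int64_t level) {
    // Levels are assumed 1-indexed in this sketch.
    if (level < 1 || static_cast<size_t>(level) > dynamicLayerStack.size()) {
      throw std::runtime_error(
          "Interpreter for this level is not on the stack");
    }
    return TensorWrapper(level, dynamicLayerStack[level - 1].is_alive_ptr());
  }

  bool is_alive() const { return *is_alive_; }

 private:
  int64_t level_;
  LifeHandle is_alive_;
};

// Usage of path (b), schematically:
//   dynamicLayerStack.emplace_back(/*level=*/1);     // enter vmap
//   auto wrapped = TensorWrapper::makeFromStack(1);  // grabs the handle
//   dynamicLayerStack.back().mark_dead();            // exit the transform
//   // wrapped.is_alive() is now false.
```

Storing the handle on the Interpreter itself is what fixes the failure mode the commit describes: if an operator's implementation re-enters functorch and two Interpreters transiently exist for the same level, each TensorWrapper still references exactly the handle of the Interpreter that created it, rather than whatever a global level-keyed map currently holds.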