pytorch
67bd2a31 - [Reland] Add python mode (#64360)

Commit
3 years ago
[Reland] Add python mode (#64360) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360 This PR adds a (private) enable_python_mode context manager. (see torch/utils/_python_dispatch.py). enable_python_mode accepts the type of a __torch_dispatch__ object as its argument. Whenever an operator gets called inside of the context manager, it dispatches to the __torch_dispatch__ of the passed-in type. Example usage: ``` with enable_python_mode(LoggingTensor): z = torch.empty([]) assert isinstance(z, LoggingTensor) ``` There are quite a few changes that were made to support this. First, we added TorchDispatchTypeObject, a C++ struct that represents the type of a `__torch_dispatch__` object (e.g. LoggingTensor). It holds both the PyObject* representing the class and a PyInterpreter* so we know which Python interpreter it came from. Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this is null, dispatching happens as usual. When it is non-null, we prepend the TorchDispatchTypeObject's PyObject* to the overloaded args list so that it is considered first for dispatch. To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser` works. The "overloaded args list" previously only consisted of Tensor PyObjects, but now it can have types in addition to Tensors! - We renamed `append_overloaded_arg` to `append_overloaded_arg` - We added a new `append_overloaded_type` that appends a type to overloaded_args - We added special handling in `handle_torch_dispatch_no_python_arg_parser` and `append_overloaded_arg` to handle types in addition to Tensors. Then, there is PythonMode and PythonModeTLS. - We reuse the DispatchKey::Python dispatch key as a mode key - We use PythonMode::enter and PythonMode::exit to enable/disable DispatchKey::Python and set the PythonModeTLS. - PythonModeTLS stores a TorchDispatchTypeObject as metadata. - PythonMode is in libtorch_python, and PythonModeTLS is in ATen. This split is due to the libtorch_python library boundary (because we need to save TLS in ATen/ThreadLocalState) - We modify the PythonFallbackKernel to look up the relevant TorchDispatchTypeObject (if Python Mode is active) and dispatch using it. There are two more miscellaneous changes: - internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an exclude guard. enable_python_mode currently does not handle torch.tensor and the exclude guard is to prevent a bug. Future: - This PR does not allow for the nesting of Python modes. In the future we should be able to enable this with a more sane no_dispatch API and by changing the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing. Test Plan: - new tests Reviewed By: ezyang Differential Revision: D30698082 Pulled By: zou3519 fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
Author
Parents
Loading