pytorch
c304fddf - [dynamo][numpy] Support graph break for numpy ndarray (#100839)

Commit
1 year ago
[dynamo][numpy] Support graph break for numpy ndarray (#100839) Issue: #93684 In previous PRs #95849 #99560 we redirect `numpy.*`, `<tensor>.numpy()` calls to `torch_np.*` methods and attributes, by creating `NumpyNdarrayVariable` for those calls. We need to handle `NumpyNdarrayVariable` when graph break happens. This PR did 2 things: 1. In `codegen.py` we made sure we can reconstruct the value wrapped by `NumpyNdarrayVariable`, to be `torch_np.ndarray` in the stack whenerver we recompiles the subgraph. 2. In `builder.py` we can wrap the value to be `NumpyNdarrayVariable` and save it as graph input. ----- Starting from commit 6: ## A new design for supporting numpy in dynamo In short the core concept doesn't change: we still convert `numpy` API calls to `torch_np` API calls. However, instead of wrapping a `torch_np.ndarray` in `NumpyNdarrayVariable`, the new design wraps a `torch.Tensor`. The reason for doing this change is because we need to keep `torch.Tensor` everywhere in the captured graph, so that it works well with the backend of dynamo. See discussions in https://github.com/Quansight-Labs/numpy_pytorch_interop/issues/142 for details. ### Flow This is an example showing how do we think about dynamo working on a simple function: ```python def f(x: torch.Tensor, y: torch.Tensor): a, b = x.numpy(), y.numpy() c = np.add(x, y) return torch.from_numpy(c) ``` ``` +------------+ +------------+ torch.Tensor | |numpy.ndarray| | -------------- .numpy() --------------| | | | | | +------------------+ +------------+ | numpy.add |numpy.ndarray| |torch.Tensor +------------+ | --------------| torch.from_numpy -------------- torch.Tensor | |numpy.ndarray| | | | -------------- .numpy() --------------| | +------------------+ | | | | +------------+ +------------+ +------------+ +----------------+ torch.Tensor | |torch.Tensor | | -------------- .detach() --------------| | | | | | +----------------+ +------------+ +------------+ | |torch_np.ndarray| |torch.Tensor| |torch.Tensor | torch_np.add -----------------| util.to_tensor -------------| .detach() -------------- +------------+ | | | | | | torch.Tensor | |torch.Tensor | | +----------------+ +------------+ -------------- .detach() --------------| | | | | | +------------+ | +----------------+ | | wrapper on torch_np.add | +--------------------------------------------------------+ ``` ### Approach `torch_np` APIs can take both `torch_np.ndarray` as well as `torch.Tensor`. What we need to do is to have a wrapper for these APIs to convert the return value back to `torch.Tensor`. This way only the wrapper is showing up in the captured graph, with `torch.Tensor`s as input and `torch.Tensor` as output. If we have a graph break or we've traced to the end of the program, we need to inspect all the `NumpyNdarrayVariable` in the stack and convert them back to `numpy.ndarray`, to make sure the compiled version is still behaving the same as the eager version. ### Examples Here's an example of the graph generated: ```python def fn(x: np.ndarray, y: np.ndarray): a = x.real b = y.real torch._dynamo.graph_break() return np.add(a, 1), np.add(b, 1) ``` Graph generated: ``` [2023-05-16 10:31:48,737] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs ------------- -------------- ---------------------------------------------------------- ---------------------- -------- placeholder l_x_ L_x_ () {} placeholder l_y_ L_y_ () {} call_function from_numpy <built-in method from_numpy of type object at 0x12b1fdc80> (l_x_,) {} call_function from_numpy_1 <built-in method from_numpy of type object at 0x12b1fdc80> (l_y_,) {} call_function attr_wrapper <function attr_wrapper at 0x12e8693a0> (from_numpy, 'real') {} call_function attr_wrapper_1 <function attr_wrapper at 0x12e8693a0> (from_numpy_1, 'real') {} output output output ((),) {} [2023-05-16 10:31:48,908] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH __compiled_fn_2 <eval_with_key>.1 opcode name target args kwargs ------------- ------------- ---------------------------------------------------------- ------------------------------- -------- placeholder l_a_ L_a_ () {} placeholder l_b_ L_b_ () {} call_function from_numpy <built-in method from_numpy of type object at 0x12b1fdc80> (l_a_,) {} call_function from_numpy_1 <built-in method from_numpy of type object at 0x12b1fdc80> (l_b_,) {} call_function wrapped_add <Wrapped function <original add>> (from_numpy, 1) {} call_function wrapped_add_1 <Wrapped function <original add>> (from_numpy_1, 1) {} output output output ((wrapped_add, wrapped_add_1),) {} ``` ### Changes * `codegen.py`: reconstruct `numpy.ndarray` from `NumpyNdarrayVariable` by adding bytecode to call `utils.to_numpy_helper()`. * `output_graph.py`: getting rid of legacy code that does exactly what `codegen.py` does, which only handling return case but not graph break case. * `utils.py`: added helpers to convert `numpy.ndarray` to `torch.Tensor` and vice versa. Also adding a wrapper class that takes in a function. In `__call__` it calls the function and converts its out to `torch.Tensor` (or a list of it). * `builder.py`: add method to wrap `numpy.ndarray` graph inputs into `NumpyNdarrayVariable`, by calling `torch.numpy` in the proxy. * `misc.py`: `numpy` API calls goes into `NumpyVariable` and we find the function with the same name in `torch_np` module, then wrap it with the wrapper defined in `utils.py`. * `tensor.py`, `torch.py`: proxy `tensor.numpy()` to be `torch.detach()` but wrap it with `NumpyNdarrayVariable`. Similarly, `torch.from_numpy()` -> `torch.detach()` but wrap it with `TensorVariable`. In `NumpyNdarrayVariable`, do the similar `torch_np.ndarray` to `torch.Tensor` wrapping for attributes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100839 Approved by: https://github.com/ezyang
Author
Committer
Parents
Loading