pytorch
10bf019b - [jit] Add shapes info to the output type of CallFunction nodes after tracing, if the output is a tensor (#95544)

Commit
2 years ago
[jit] Add shapes info to the output type of CallFunction nodes after tracing, if the output is a tensor (#95544) **Summary**: jit.trace usually adds shape information to all the jit::Values in its graph. This is mostly a side effect of how jit tracing is performed, but many users use this behavior for debugging and for better understanding the graph. Previously, CallFunction nodes (inserted by calling jit.script-ed functions) did _not_ have this information attached. This PR attaches this information for the tensor output values. **Details**: * First the jit tracer sets a global TracerState object * Then the jit tracer invokes the python callable that is to be traced * When the python function gets to a jit.script-ed function, [invokeScriptFunctionFromPython](https://github.com/pytorch/pytorch/blob/8693604bc6274fef8484d556e71b999e1d4d1013/torch/csrc/jit/python/pybind_utils.h#L1060) is called. It inserts a FunctionCall. * Then after the actual scripted function gets called and we have a concrete output, we attach the concrete output [IValue to the TracerState](https://github.com/pytorch/pytorch/blob/8693604bc6274fef8484d556e71b999e1d4d1013/torch/csrc/jit/python/pybind_utils.h#L1001) * ^^ the setValueTrace call (linked in previous list item) is where this PR makes changes; we revise the jit::Value output of the CallFunction node to use the type of the concrete tensor, which will have actual shapes associated. **Test**: added a test verifying that shape info appears in the output type for a CallFunction node in a jit-traced graph. Differential Revision: [D43592880](https://our.internmc.facebook.com/intern/diff/D43592880) Pull Request resolved: https://github.com/pytorch/pytorch/pull/95544 Approved by: https://github.com/qihqi
Author
Committer
Parents
Loading