pytorch
65233456 - Fix missing element types and shapes when autograd.Function has multiple tensor outputs (#57966)

Commit
3 years ago
Fix missing element types and shapes when autograd.Function has multiple tensor outputs (#57966) Summary: When generating IR for autograd.Function, if the function has multiple outputs, a TupleUnpack may be inserted after the original function node, and Pytorch only assigns proper information (tensor element type and shape) to the TupleUnpack and forgets the original function node. In contrast, if autograd.Function only produces one output, the original function node may have tensor element type and shape in its output schema. Before this PR: - (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output (tensor, dtype=float32, shape=[4, 5]) - (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output_0 **(tensor)**, output_1 **(tensor)** -> TupleUnpack output_2 (tensor, dtype=float32, shape=[4, 5]), output_3 (tensor, dtype=float32, shape=[6, 7]) After this PR: - (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp -> output (tensor, dtype=float32, shape=[4, 5]) - (simplified) IR for autograd.Function with one output: input (tensor, dtype=float32, shape=[2, 3]) -> PythonOp ->output_0 **(tensor, dtype=float32, shape=[4, 5])**, output_1 **(tensor, dtype=float32, shape=[6, 7])** -> TupleUnpack output_2 (tensor, dtype=float32, shape=[4, 5]), output_3 (tensor, dtype=float32, shape=[6, 7]) Pull Request resolved: https://github.com/pytorch/pytorch/pull/57966 Reviewed By: zhxchen17 Differential Revision: D30208207 Pulled By: gmagogsfm fbshipit-source-id: 42a3d1f9c0932133112a85df0c49cf4ea0afa175
Author
Committer
Parents
Loading