pytorch
afa489de - [ONNX] Enable lower_tuple pass for custom layer (#41548)

Commit
4 years ago
[ONNX] Enable lower_tuple pass for custom layer (#41548) Summary: Custom layer by `torch.autograd.Function` appears in the lower_tuple as `prim::PythonOp`. Adding this op type to the allowed list to enable lower_tuple pass. This helps with exporting custom layer with tuple outputs. E.g. ```python import torch class CustomFunction(torch.autograd.Function): staticmethod def symbolic(g, input): return g.op('CustomNamespace::Custom', input, outputs=2) staticmethod def forward(ctx, input): return input, input class Custom(torch.nn.Module): def forward(self, input): return CustomFunction.apply(input) model = Custom() batch = torch.FloatTensor(1, 3) torch.onnx.export(model, batch, "test.onnx", verbose=True) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/41548 Reviewed By: glaringlee Differential Revision: D22926143 Pulled By: bzinodev fbshipit-source-id: ce14d1d3c70a920154a8235d635ab31ddf0c46f3
Author
Parents
Loading