[ONNX] Support registering custom export for prim::PythonOp from torch.autograd.Function (#55630) (#57600)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57600
Demo script:
```python
import torch
class MyReLU(torch.autograd.Function):
    """ReLU-like custom autograd Function used to demo prim::PythonOp export.

    Forward clamps the input at ``scalar``; backward zeroes the gradient
    wherever the saved input is negative. The extra tuple/scalar/list
    arguments exist only to show how non-tensor inputs flow through
    prim::PythonOp during ONNX export.
    """

    @staticmethod
    def forward(ctx, input, scalar_tuple, scalar, scalar_list):
        # Save the raw input so backward can rebuild its mask from it.
        ctx.save_for_backward(input)
        return input.clamp(min=scalar)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # NOTE(review): the mask threshold is 0 (plain ReLU) even though
        # forward clamps at `scalar` -- kept as in the original demo.
        grad_input[input < 0] = 0
        # backward must return one gradient per forward input; the
        # non-tensor arguments (tuple, scalar, list) get None.
        return grad_input, None, None, None
class MyModule(torch.nn.Module):
    """Tiny demo network: linear -> custom MyReLU -> linear."""

    def __init__(self):
        super().__init__()
        self.linear_a = torch.nn.Linear(2, 2)
        self.linear_b = torch.nn.Linear(2, 2)
        # Bind the custom Function's entry point as the activation.
        self.relu = MyReLU.apply

    def forward(self, x):
        hidden = self.linear_a(x)
        # Extra non-tensor args demonstrate scalar inputs to prim::PythonOp.
        activated = self.relu(hidden, (5, 3), 2, [1, 2, 3])
        return self.linear_b(activated)
"""
User define how to export prim::PythonOp into custom op.
"""
def symbolic_pythonop(g, n, *args, **kwargs):
    """Example symbolic that maps prim::PythonOp to CustomDomain::PythonOp.

    ``g`` is the ONNX graph being built, ``n`` the original prim::PythonOp
    node; ``args``/``kwargs`` carry the node's inputs and attributes.
    """
    import torch.onnx.symbolic_helper as sym_helper
    # Diagnostic dump of everything the exporter hands us.
    print('arguments of ', kwargs['name'], ':')
    print('original node: ', n)
    for idx, output in enumerate(n.outputs()):
        print('original output {}: {}, requires grad: {}'.format(idx, output, output.requiresGrad()))
    for idx, arg in enumerate(args):
        # Only graph Values carry requires-grad info; scalars do not.
        requires = arg.requiresGrad() if sym_helper._is_value(arg) else False
        print('arg {}: {}, requires grad: {}'.format(idx, arg, requires))
    for key, value in kwargs.items():
        print('key: ', key, ' v: ', value)
    # TODO: all inputs (tensors and scalars) are in args.
    # backend can define CustomDomain::PythonOp and how info are stored however it deem fit.
    return g.op("CustomDomain::PythonOp", args[0], name_s=kwargs['name'])
# Hook the custom symbolic in for prim::PythonOp, starting at opset 9.
torch.onnx.register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 9)

# Define input.
x = torch.tensor([[0.3971, 0.7544],
                  [0.5695, 0.4388]], requires_grad=True)
model = MyModule()

# Forward.
y = model(x)
torch.onnx.export(model, (x,), 'model.onnx', opset_version=12, verbose=True)
```
Test Plan: Imported from OSS
Reviewed By: malfet
Differential Revision: D28393528
Pulled By: SplitInfinity
fbshipit-source-id: e0d55b7c737c5916fda08a3b26b3306037f970df
Co-authored-by: BowenBao <bowbao@microsoft.com>