Add support for serializing a GraphModule to JSON (#47612)
Summary:
Re-opening the previous PR, which missed some mypy issues; those issues are now addressed.
Example:
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
self.e = torch.rand(4)
def forward(self, a, b):
add_1 = a + b
linear = self.linear(add_1)
add_2 = linear + self.e
return add_2
JSON:
{
"modules": {},
"weights": {
"linear.weight": {
"dtype": "torch.float32",
"is_quantized": false,
"shape": "[4, 4]"
},
"linear.bias": {
"dtype": "torch.float32",
"is_quantized": false,
"shape": "[4]"
},
"e": {
"dtype": "torch.float32",
"is_quantized": false,
"shape": "[4]"
}
},
"nodes": [
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "a",
"op_code": "placeholder",
"name": "a",
"args": [],
"kwargs": {}
},
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "b",
"op_code": "placeholder",
"name": "b",
"args": [],
"kwargs": {}
},
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "_operator.add",
"op_code": "call_function",
"name": "add_1",
"args": [
{
"is_node": true,
"name": "a"
},
{
"is_node": true,
"name": "b"
}
],
"kwargs": {}
},
{
"target": "linear",
"op_code": "call_module",
"name": "linear_1",
"args": [
{
"is_node": true,
"name": "add_1"
}
],
"kwargs": {}
},
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "e",
"op_code": "get_attr",
"name": "e",
"args": [],
"kwargs": {}
},
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "_operator.add",
"op_code": "call_function",
"name": "add_2",
"args": [
{
"is_node": true,
"name": "linear_1"
},
{
"is_node": true,
"name": "e"
}
],
"kwargs": {}
},
{
"shape": "[4]",
"dtype": "torch.float32",
"target": "output",
"op_code": "output",
"name": "output",
"args": [
{
"is_node": true,
"name": "add_2"
}
],
"kwargs": {}
}
]
}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47612
Reviewed By: scottxu0730
Differential Revision: D24836223
Pulled By: gcatron
fbshipit-source-id: d3da2b5f90d143beba3b7f1f67462fb7430df906