a47f93b6 - Add type and shape annotation for gm.print_readable() (#86562)

For

```
def f(a, b):
    dim0 = a.shape[0] + b.shape[0]
    dim1 = a.shape[1] + b.shape[1]
    d = a.new_empty(dim0, dim1)
    return d

fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
fx_g.print_readable()
```

Tracing with 'real' and 'fake' mode yields

```
class f(torch.nn.Module):
    def forward(self, a_1: Tensor<f32>[5, 3], b_1: Tensor<f32>[4, 3]):
        # No stacktrace found for following nodes
        new_empty: Tensor<f32>[9, 6] = torch.ops.aten.new_empty.default(a_1, [9, 6], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  a_1 = None
        return new_empty
```

Tracing with 'symbolic' mode yields

```
def forward(self, a_1: Tensor<f32>[t0.size(0), t0.size(1)], b_1: Tensor<f32>[t1.size(0), t0.size(1)]):
    # No stacktrace found for following nodes
    sym_size: Symint(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
    sym_size_1: Symint(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
    add: Symint(t0.size(0) + t1.size(0)) = sym_size + sym_size_1;  sym_size = sym_size_1 = None
    sym_size_2: Symint(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
    sym_size_3: Symint(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1);  b_1 = None
    add_1: Symint(2*t0.size(1)) = sym_size_2 + sym_size_3;  sym_size_2 = sym_size_3 = None
    new_empty: Tensor<f32>[t0.size(0) + t1.size(0), 2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  a_1 = add = add_1 = None
    return new_empty
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86562
Approved by: https://github.com/Chillee
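For reference, a self-contained version of the repro above (a minimal sketch, assuming `make_fx` is imported from `torch.fx.experimental.proxy_tensor`; the snippet in the commit message omits the imports):

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, b):
    dim0 = a.shape[0] + b.shape[0]
    dim1 = a.shape[1] + b.shape[1]
    d = a.new_empty(dim0, dim1)
    return d

# Trace with symbolic shapes so sizes appear as SymInt expressions
# rather than concrete integers in the printed graph.
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))

# With this change, each node in the printed graph is annotated with
# its dtype and (possibly symbolic) shape.
fx_g.print_readable()
```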