pytorch
6df4da99 - [ONNX] Warning when using __len__ to calculate tensor shape (#55151)

Commit
3 years ago
Tracing a model that calls len(tensor) produces a different graph, and different outputs, than tracing the same model with tensor.shape[0].

An example model (using tensor.shape):

```
# Test len fix with variable inputs
import io

import torch
import onnxruntime

class Model(torch.nn.Module):
    def forward(self, x):
        return x.size(1) + x.shape[0]

# Call export
dummy_x = torch.randn(3, 5)
model = Model()

onnx_io = io.BytesIO()
torch.onnx.export(model, (dummy_x,), onnx_io,
                  input_names=['input'],
                  dynamic_axes={'input': {0:'h'}},
                  verbose=True)

# Call onnxruntime runtime and compare outputs on dynamic inputs
ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())

x = torch.randn(2, 5)
print(model(x))
print(ort_session.run(None, {ort_session.get_inputs()[0].name: x.numpy()}))
```

The output graph is as follows:

```
graph(%input : Float(*, 5, strides=[5, 1], requires_grad=0, device=cpu)):
  %1 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
  %2 : Long(device=cpu) = onnx::Constant[value={1}]()
  %3 : Long(device=cpu) = onnx::Gather[axis=0](%1, %2) # test/onnx/test_m.py:9:0
  %4 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
  %5 : Long(device=cpu) = onnx::Constant[value={0}]()
  %6 : Long(device=cpu) = onnx::Gather[axis=0](%4, %5) # test/onnx/test_m.py:9:0
  %7 : Long(requires_grad=0, device=cpu) = onnx::Add(%3, %6) # test/onnx/test_m.py:9:0
  return (%7)
```

Torch output: 7
ORT output: 7

Now replacing tensor.shape[0] with len(tensor), the graph looks like:

```
graph(%input : Float(*, 5, strides=[5, 1], requires_grad=0, device=cpu)):
  %1 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
  %2 : Long(device=cpu) = onnx::Constant[value={1}]()
  %3 : Long(device=cpu) = onnx::Gather[axis=0](%1, %2) # test/onnx/test_m.py:9:0
  %4 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={3}]()
  %5 : Long(requires_grad=0, device=cpu) = onnx::Add(%3, %4)
  return (%5)
```

Torch output: 7
ORT output: 8

In the case with __len__, %4 is traced as a constant (3, the first dimension of the dummy input), so the exported graph ignores the dynamic axis and ORT returns 5 + 3 = 8 instead of 5 + 2 = 7.
**Workaround to avoid the mismatch when using len to get tensor.shape**

Add the following pattern around the _export call:

```
import builtins
len_backup = builtins.len
builtins.len = lambda x : x.__len__()
# Call export
_export(model, args, .....
builtins.len = len_backup
```

Co-authored-by: shubhambhokare1 <shubhambhokare1@gmail.com>
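
The workaround pattern can also be wrapped so that the original len is restored even if the export raises. The following is a minimal sketch, not part of this commit: the helper name `patched_len` is hypothetical, and the export call is only indicated in a comment.

```python
import builtins
from contextlib import contextmanager


@contextmanager
def patched_len():
    """Temporarily make len(x) call x.__len__() directly.

    Hypothetical helper: same swap as the workaround pattern,
    with the restore moved into a finally block.
    """
    len_backup = builtins.len
    builtins.len = lambda x: x.__len__()
    try:
        yield
    finally:
        # Always restore the real builtin, even when the body raises.
        builtins.len = len_backup


# Intended usage (assumed, not run here):
# with patched_len():
#     ...call the export function here...
```

Because the swap replaces a process-wide builtin, keeping the patched region as small as possible seems prudent.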