[ONNX] Warning when using __len__ to calculate tensor shape (#55151)
Using len(tensor) produces a different traced graph, and different outputs, than the equivalent tensor.shape[0].
An example model (using tensor.shape):
```
# Test len fix with variable inputs
import io

import torch
import onnxruntime

class Model(torch.nn.Module):
    def forward(self, x):
        return x.size(1) + x.shape[0]

# Call export
dummy_x = torch.randn(3, 5)
model = Model()
onnx_io = io.BytesIO()
torch.onnx.export(model, (dummy_x,), onnx_io,
                  input_names=['input'],
                  dynamic_axes={'input': {0: 'h'}},
                  verbose=True)

# Run with onnxruntime and compare outputs on dynamic inputs
ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
x = torch.randn(2, 5)
print(model(x))
print(ort_session.run(None, {ort_session.get_inputs()[0].name: x.numpy()}))
```
The output graph is as follows:
```
graph(%input : Float(*, 5, strides=[5, 1], requires_grad=0, device=cpu)):
%1 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
%2 : Long(device=cpu) = onnx::Constant[value={1}]()
%3 : Long(device=cpu) = onnx::Gather[axis=0](%1, %2) # test/onnx/test_m.py:9:0
%4 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
%5 : Long(device=cpu) = onnx::Constant[value={0}]()
%6 : Long(device=cpu) = onnx::Gather[axis=0](%4, %5) # test/onnx/test_m.py:9:0
%7 : Long(requires_grad=0, device=cpu) = onnx::Add(%3, %6) # test/onnx/test_m.py:9:0
return (%7)
```
Torch output: 7
ORT output: 7
Now replacing tensor.shape[0] with len(tensor), the graph looks like:
```
graph(%input : Float(*, 5, strides=[5, 1], requires_grad=0, device=cpu)):
%1 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
%2 : Long(device=cpu) = onnx::Constant[value={1}]()
%3 : Long(device=cpu) = onnx::Gather[axis=0](%1, %2) # test/onnx/test_m.py:9:0
%4 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={3}]()
%5 : Long(requires_grad=0, device=cpu) = onnx::Add(%3, %4)
return (%5)
```
Torch output: 7
ORT output: 8
In the case with __len__, %4 is traced as a constant: Python requires __len__ to return a plain int, so the tracer bakes in the batch size of the dummy input (3) instead of recording a dynamic shape lookup.
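The same behavior can be reproduced directly with the tracer, without going through ONNX export. A minimal sketch (a standalone illustration, not code from this PR):

```python
import torch

def with_len(x):
    # len() must return a plain Python int, so the tracer
    # bakes in the dummy input's batch size as a constant
    return x * len(x)

def with_shape(x):
    # shape[0] is recorded as a dynamic size lookup
    return x * x.shape[0]

traced_len = torch.jit.trace(with_len, torch.randn(3, 5))
traced_shape = torch.jit.trace(with_shape, torch.randn(3, 5))

# Call both traced functions with a different batch size
x = torch.ones(2, 5)
print(traced_len(x)[0, 0].item())    # 3.0 -- stale constant from the dummy input
print(traced_shape(x)[0, 0].item())  # 2.0 -- follows the actual input shape
```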
**Workaround to avoid the mismatch when using len to get tensor.shape**
Add the following pattern around the _export call:
```
import builtins
len_backup = builtins.len
builtins.len = lambda x: x.__len__()
# Call export
_export(model, args, ...)
builtins.len = len_backup
```
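The same pattern can be wrapped in a context manager so the builtin is restored even if export raises (a sketch; patched_len is a hypothetical helper name, not part of torch):

```python
import builtins
from contextlib import contextmanager

@contextmanager
def patched_len():  # hypothetical helper, not part of torch
    # Route len() through __len__ directly, bypassing the builtin's
    # coercion to a plain Python int that the tracer bakes in as a constant
    len_backup = builtins.len
    builtins.len = lambda x: x.__len__()
    try:
        yield
    finally:
        # Restore the builtin even if the export call fails
        builtins.len = len_backup
```

Usage: wrap the export call, e.g. `with patched_len(): torch.onnx.export(...)`.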
Co-authored-by: shubhambhokare1 <shubhambhokare1@gmail.com>