`symbolic_shape_infer.py`: Fix slicing a tensor that has a sympy.Min() in its shape (#14384)
### Description
`_infer_Slice()` is a function (arguably the most complex one) in
`symbolic_shape_infer.py` that infers the shape of the output of a
`Slice` node. This commit fixes an edge case in `_infer_Slice()` caused
by a SymPy quirk.
When both the end of the slice (let's call it `e`) and the corresponding
dimension of the sliced tensor (let's call it `dim`) are arbitrary
symbolic expressions, `symbolic_shape_infer.py`
[checks](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1728)
if `e <= dim`. Comparing symbolic expressions is hard in general, so if
the comparison fails, `symbolic_shape_infer.py` [gives
up](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1734)
and assumes that `e` is equal to `dim`.
A failure of this sort currently happens for expressions of the form `Y
- X >= 0` where `Y` contains a `sympy.Min()` (`symbolic_shape_infer.py`
tries to rewrite `X <= Y` comparisons in various ways, and `Y - X >= 0`
is [one of
them](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L1664)).
An simple example to illustrate this:
```python
>>> import sympy
>>> X = sympy.Symbol('X', positive=True, integer=True)
>>>
>>> y1 = 9999
>>> Y1 = X + y1 - 5000
>>> bool(Y1 - X >= 0)
True
>>>
>>> y2 = X + 4999
>>> Y2 = X + y2 - 5000
>>> bool(Y2 - X >= 0)
True
>>>
>>> y3 = sympy.Min(y1, y2)
>>> Y3 = X + y3 - 5000
>>> bool(Y3 - X >= 0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File ".../venv/lib/python3.9/site-packages/sympy/core/relational.py", line 511, in __bool__
raise TypeError("cannot determine truth value of Relational")
TypeError: cannot determine truth value of Relational
```
If you assume that `X` is positive symbol (`symbolic_shape` [does
assume](https://github.com/microsoft/onnxruntime/blob/de7a868d5f3390d7c095a53c26abd39f402f3f93/onnxruntime/python/tools/symbolic_shape_infer.py#L2129)
this for graph inputs), then both `Y1 >= X` and `Y2 >= X` holds, and
SymPy can prove this. This means that `Y3 >= X` also holds (since `Y3`
is essentially equal to either `Y1` or `Y2`, depending on the value of
`X`), but this is too hard for SymPy to prove. I confirmed that this is
still the case for the latest SymPy version (`1.11.1`).
This commit tries to fix this edge case by slightly rewriting the
expression containing `sympy.Min()`. I explain the details in the
comments in `symbolic_shape_infer.py`, so I won't duplicate them in the
PR description.
### Motivation and Context
This sounds like a very contrived example, but it actually appeared in
the wild when we tried to infer shapes for an ONNX graph exported from
PyTorch that used relative-position multihead attention from Fairseq.
The problematic line is
[here](https://github.com/facebookresearch/fairseq/blob/7d050ada7d365b535bf7c634ed3bcaf1cc2930b1/fairseq/modules/espnet_multihead_attention.py#L192).
In our codebase, we have something like `matrix_bd = matrix_bd[:, :, :,
: matrix_ac.size(-1)]` before we add `matrix_ac` and `matrix_bd`.
`matrix_bd` is itself a result of another slice, hence its shape
contains `sympy.Min()`, and the SymPy weirdness described above prevents
`symbolic_shape_infer.py` from correctly inferring the final shape of
`matrix_bd`. Then `symbolic_shape_infer.py` explodes when we try to add
`matrix_ac` and `matrix_bd`, because their shapes are not compatible.
I added a small self-contained unit test to illustrate the problem.
*Without* the fix, `slice_out_cropped` has shape `[N + Min(42, N + 21) -
22]`, and `input` has shape `[N]`, and we get this:
```
> python onnxruntime_test_python_symbolic_shape_infer.py
..................Cannot determine if 22 - N < 0
Unable to determine if N <= N + Min(42, N + 21) - 22, treat as equal
E....
======================================================================
ERROR: test_slice_of_min (__main__.TestSymbolicShapeInferenceForSlice)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/home/dfyz/onnxruntime/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py", line 460, in test_slice_of_min
model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def))
File "/home/dfyz/onnxruntime/onnxruntime/test/python/../../python/tools/symbolic_shape_infer.py", line 2461, in infer_shapes
raise Exception("Incomplete symbolic shape inference")
Exception: Incomplete symbolic shape inference
----------------------------------------------------------------------
Ran 23 tests in 0.486s
FAILED (errors=1)
```
*With* the fix, both tensors have shape `[N]`, and the test passes.
---------
Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>