sym_max/sym_min introduce guard if hinted (#94400)
This patch started with only the change in `torch/_prims_common/__init__.py`. Unfortunately, that change by itself fails tests. The reason is that `sym_max` produces a `sympy.Max` expression, which impedes our ability to actually reason symbolically about the resulting expressions. We would much prefer to insert a guard on `l > 1` and get a sympy expression without `Max` in it, if we can. In the upcoming unbacked SymInts PR we cannot always do this, but without unbacked SymInts we always can.
To do this, we introduce `alternate_impl_if_hinted_methods`. The idea is that if all of the arguments to max/min have hints, we just go ahead and introduce a guard, and then return one argument or the other depending on its result. This is done by rewrapping the SymNodes into SymInt/SymFloat and then running `builtins.min`/`builtins.max`, though we could also have done the guarding manually (see also https://github.com/pytorch/pytorch/pull/94365).
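The hinted-dispatch idea above can be sketched in miniature. This is an illustrative toy, not the real torch implementation: `SymNode`, `guards`, and `sym_max` here are simplified stand-ins, and the real code rewraps into SymInt/SymFloat and calls `builtins.max` rather than comparing hints directly.

```python
class SymNode:
    """Toy stand-in for torch's SymNode: a symbolic expression plus an
    optional concrete hint value (illustrative only, not the real API)."""
    def __init__(self, expr, hint=None):
        self.expr = expr  # symbolic expression, modeled here as a string
        self.hint = hint  # concrete example value, if one is known

guards = []  # record of guards we have committed to

def sym_max(a, b):
    # If both operands are hinted, guard on the comparison and return one
    # of the original nodes directly -- no Max() expression is created.
    if a.hint is not None and b.hint is not None:
        cond = a.hint >= b.hint
        guards.append((f"{a.expr} >= {b.expr}", cond))
        return a if cond else b
    # Unbacked case: we cannot guard, so fall back to a symbolic Max.
    return SymNode(f"Max({a.expr}, {b.expr})")

l = SymNode("l", hint=5)
one = SymNode("1", hint=1)
result = sym_max(l, one)
# result is the node `l` itself, and a guard on `l >= 1` was recorded
```

Note that the guarded branch hands back one of the input nodes unchanged, which is exactly what sets up the constant-return subtlety described below.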
However, a very subtle problem emerges when you do this. When we use `builtins.min`/`builtins.max`, we return one of the argument SymNodes directly, without allocating a fresh SymNode. Suppose we do a min/max with a constant (as is the case in `sym_max(l, 1)`). This means we can return a constant SymNode as the result of the computation. Constant SymNodes get transformed into regular integers, which then trigger the assert at https://github.com/pytorch/pytorch/pull/94400/files#diff-03557db7303b8540f095b4f0d9cd2280e1f42f534f67d8695f756ec6c02d3ec7L620
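The subtlety is visible in plain Python: `builtins.max` returns one of its argument objects unchanged, so when a constant operand wins, the constant node itself flows out of the computation. The sketch below uses a hypothetical `FakeSymInt` wrapper over dict "nodes"; only the identity behavior of `max` is the point.

```python
class FakeSymInt:
    """Hypothetical wrapper playing the role of SymInt: builtins.max
    compares these via __gt__ and returns one of them unchanged."""
    def __init__(self, node):
        self.node = node
    def __gt__(self, other):
        return self.node["hint"] > other.node["hint"]

const_node = {"expr": 1, "hint": 1, "constant": True}
sym_node = {"expr": "l", "hint": 3, "constant": False}

winner = max(FakeSymInt(sym_node), FakeSymInt(const_node))
# builtins.max returned the argument object itself; no fresh node was made
assert winner.node is sym_node

small = {"expr": "l", "hint": 0, "constant": False}
winner2 = max(FakeSymInt(small), FakeSymInt(const_node))
# When the constant wins, the constant node flows out directly -- this is
# the case that used to trip the assert in binary_magic_impl.
assert winner2.node is const_node
```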
After thinking about this a bit, I think the assert is wrong: it should be OK for SymNode methods to return constants. The reason the assert was originally added is that ProxyTensorMode cannot trace a constant return. But this is fine: if you return a constant, no tracing is necessary; you know you have enough guards that the result is guaranteed to be a constant no matter what the input arguments are, so you can burn it in. You might also be wondering why a change to a SymNode method affects an assert in the dispatch mode dispatch: the call stack typically looks like SymNode.binary_magic_impl -> SymProxyTensorMode -> SymNode.binary_magic_impl again, so you hit binary_magic_impl twice!
No new tests: the use of sym_max breaks preexisting tests, and the rest of the PR makes them pass again.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94400
Approved by: https://github.com/Chillee