Support runtime assertion for inline constraints (#100763)
This pr does the following:
1. previously, inline constraints is not properly set for tensor output data-dependent ops such as a.nonzero because of its return value is not symint. This pr just uses all the unbacked symbols i.e.those start with "i"/"f" in create_unbacked_sym* functions. Note that these symbols are guaranteed to be a super set of inline user constraints.
2. add inline assertions support by checking.
Currently, it only deal with tensor, SymInt, SymFloat, SymBool output data-dependent ops and ignore the rest. It's good enough for now as we only have a limited number of data-dependent ops (.item and .nonzero are explicitly tested).
The examples for graph that is added assertions is shown below:
```
class ExportGraphModule(torch.nn.Module):
def forward(self, x):
arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
nonzero_default: i64[i0, 1] = torch.ops.aten.nonzero.default(arg0); arg0 = None
return pytree.tree_unflatten([nonzero_default], self._out_spec)
class GraphModule(torch.nn.Module):
def forward(self, x):
arg0: i64[s0], = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size: Sym(s0) = torch.ops.aten.sym_size(arg0, 0)
nonzero_default: i64[i1, 1] = torch.ops.aten.nonzero.default(arg0); arg0 = None
sym_size_1: Sym(i1) = torch.ops.aten.sym_size(nonzero_default, 0)
ge: Sym(i1 >= 3) = sym_size_1 >= 3
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(ge); ge = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].'); scalar_tensor_default = None
le: Sym(i1 <= 5) = sym_size_1 <= 5; sym_size_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(le); le = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'nonzero_default.shape[0] is outside of inline constraint [3, 5].'); scalar_tensor_default_1 = None
return pytree.tree_unflatten([nonzero_default], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100763
Approved by: https://github.com/tugsbayasgalan