Allow str inputs in non-strict tracing (#120536)
Previously, torch.export in non-strict mode was failing on str inputs while creating fake inputs for tracing (fakify()), and using graph nodes to create constraints. This fixes those 2 stages to allow strs to pass.
Failing test case:
```
class Foo(torch.nn.Module):
def forward(self, a, b, mode):
return torch.div(a, b, rounding_mode=mode)
foo = Foo()
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = export(foo, inps)
with self.assertRaisesRegex(
RuntimeError, "to be equal to trunc, but got floor"
):
_ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
```
Before:
```
(pytorch-local) pianpwk@pianpwk-mbp pytorch % python test/export/test_export_nonstrict.py -k test_runtime_assert_for_prm_str
E
======================================================================
ERROR: test_runtime_assert_for_prm_str_non_strict (__main__.NonStrictExportTestExport.test_runtime_assert_for_prm_str_non_strict)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/pianpwk/Documents/pytorch/torch/testing/_internal/common_utils.py", line 2744, in wrapper
method(*args, **kwargs)
File "/Users/pianpwk/Documents/pytorch/test/export/testing.py", line 40, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/test/export/test_export.py", line 1588, in test_runtime_assert_for_prm_str
exported = export(foo, inps)
^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/test/export/test_export_nonstrict.py", line 16, in mocked_non_strict_export
return export(*args, **kwargs, strict=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/export/__init__.py", line 186, in export
return _export(
^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/export/_trace.py", line 541, in wrapper
raise e
File "/Users/pianpwk/Documents/pytorch/torch/export/_trace.py", line 527, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/export/exported_program.py", line 83, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/export/_trace.py", line 707, in _export
) = make_fake_inputs(f, args, kwargs, constraints)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/_export/non_strict_utils.py", line 133, in make_fake_inputs
fake_args, fake_kwargs = tree_map_with_path(
^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/utils/_pytree.py", line 1519, in tree_map_with_path
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/utils/_pytree.py", line 734, in unflatten
leaves = list(leaves)
^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/utils/_pytree.py", line 1519, in <genexpr>
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/_export/non_strict_utils.py", line 134, in <lambda>
lambda kp, val: fakify(fake_mode, kp, val, t_constraints, sources),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/pianpwk/Documents/pytorch/torch/_export/non_strict_utils.py", line 68, in fakify
raise ValueError("Only tensors allowed as input")
ValueError: Only tensors allowed as input
To execute this test, run the following from the base repo dir:
python test/export/test_export_nonstrict.py -k test_runtime_assert_for_prm_str_non_strict
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 1 test in 0.008s
FAILED (errors=1)
```
After:
```
(pytorch-local) pianpwk@pianpwk-mbp pytorch % python test/export/test_export_nonstrict.py -k test_runtime_assert_for_prm_str
.
----------------------------------------------------------------------
Ran 1 test in 0.237s
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120536
Approved by: https://github.com/tugsbayasgalan, https://github.com/zhxchen17, https://github.com/avikchaudhuri, https://github.com/gmagogsfm