pytorch
31eb9949 - [dynamo] disallow_in_graph bugfix (#99600)

Commit
1 year ago
[dynamo] disallow_in_graph bugfix (#99600) Testing if the minor change breaks other test cases. For the added test case, TorchDynamo causes graph break on `torch.ops.foo.custom` but then again starts running on the recursively invoked frame - `foo_cpu` on L48 in testfile. This raises assertion like this ~~~ Traceback (most recent call last): File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 65, in test_disallow_in_graph_for_custom_op res = opt_fn(x) File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn return fn(*args, **kwargs) File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 56, in fn b = torch.ops.foo.custom(a) File "/scratch/anijain/work/pytorch/torch/_ops.py", line 646, in __call__ return self._op(*args, **kwargs or {}) File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 401, in catch_errors return callback(frame, cache_size, hooks, frame_state) File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 495, in _convert_frame result = inner_convert(frame, cache_size, hooks, frame_state) File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn return fn(*args, **kwargs) File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert return _compile( File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper r = func(*args, **kwargs) File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile out_code = transform_code_object(code, transform) File "/scratch/anijain/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object transformations(instructions, code_options) File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 371, in transform tracer = InstructionTranslator( File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1890, in __init__ self.symbolic_locals = collections.OrderedDict( File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1893, in <genexpr> VariableBuilder( File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 165, in __call__ return self._wrap(value).clone(**self.options()) File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 290, in _wrap return type_dispatch(self, value) File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 776, in wrap_tensor tensor_variable = wrap_fx_proxy( File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy return wrap_fx_proxy_cls( File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 983, in wrap_fx_proxy_cls example_value = wrap_to_fake_tensor_and_record( File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1213, in wrap_to_fake_tensor_and_record fake_e = wrap_fake_exception( File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 835, in wrap_fake_exception return fn() File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1214, in <lambda> lambda: tx.fake_mode.from_tensor( File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 1434, in from_tensor return self.fake_tensor_converter( File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 329, in __call__ return self.from_real_tensor( File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 283, in from_real_tensor out = self.meta_converter( File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 531, in __call__ r = self.meta_tensor( File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 184, in meta_tensor assert not torch._C._dispatch_tls_local_exclude_set().has( AssertionError: ~~~ It seems `_dynamo.disable` is the right option for custom ops added by `torch.library`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/99600 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading