Better error handling for cond (#108817)
## Exception in cond:
For code below:
```python
import torch
import functorch.experimental.control_flow as control_flow
def true_fn(x):
return x.sin()
def false_fn(x):
return x, x
def f(x, y):
return control_flow.cond(y, true_fn, false_fn, [x])
f(torch.ones(3, 4), torch.tensor(False))
```
The original exception stack trace is:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
f(torch.ones(3, 4), torch.tensor(False))
File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
return control_flow.cond(y, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1164, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 418, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Expected branch to return a single tensor
from user code:
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
After this PR we get:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 429, in call_function
(false_r, false_graph, false_lifted_freevars) = speculate_branch(False)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 421, in speculate_branch
unimplemented(
File "/home/yidi/local/pytorch/torch/_dynamo/exc.py", line 187, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: Expected branch to return a single tensor
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 33, in <module>
f(torch.ones(3, 4), torch.tensor(False))
File "/home/yidi/local/pytorch/test_exc.py", line 31, in f
return control_flow.cond(y, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 154, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1159, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars.items)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
from user code:
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
## Exception during speculating branches
The example code below has a inplace-buffer mutation error,
```python
import torch
import functorch.experimental.control_flow as control_flow
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(6, 4))
def forward(self, x):
def true_fn(x):
self.buffer += 1
return self.buffer.sum() + x.sum()
def false_fn(x):
return (x - 1).sum()
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
mod_for_compile = torch.compile(Foo(), backend="eager", dynamic=True)
mod_for_compile(torch.ones(3, 4))
```
Before this PR the exception looks like:
```python
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:20:03,332] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 163, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1219, in STORE_ATTR
.call_function(self, [obj, ConstantVariable(inst.argval), val], {})
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 394, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 222, in speculate_subgraph
raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
mod_for_compile(torch.ones(3, 4))
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 365, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 513, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 632, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 560, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 197, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 482, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 449, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2083, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 606, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2200, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2316, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 733, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 696, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 397, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1124, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 570, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 415, in call_function
(true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 405, in speculate_branch
raise UncapturedHigherOrderOpError(
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile
from user code:
File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
return cond_op(pred, true_fn, false_fn, operands)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
after this PR, the only difference is the error message of UncapturedHigherOrderOpError changes from `Cond doesn't work unless it is captured completely with torch.compile` to `Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break`.
```python
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-09-08 15:17:02,052] [0/0] torch._dynamo.variables.higher_order_ops: [ERROR] Can't inplace modify module params/buffers inside HigherOrderOp
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 177, in speculate_subgraph
output = f.call_function(tx, args, sub_kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1214, in STORE_ATTR
.call_function(self, [obj, ConstantVariable(inst.argval), val], {})
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 618, in call_function
result = handler(tx, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 1169, in call_setattr
raise AttributeMutationError(
torch._dynamo.exc.AttributeMutationError: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 50, in graph_break_as_hard_error
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 426, in call_function
(true_r, true_graph, true_lifted_freevars) = speculate_branch(True)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 410, in speculate_branch
ret_val, ret_graph, ret_lifted_freevars = speculate_subgraph(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 236, in speculate_subgraph
raise Unsupported(
torch._dynamo.exc.Unsupported: speculate_subgraph: while introspecting cond, we were unable to trace function `true_fn` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown. Scroll up for the stack trace of the initial exception. The reason was: Can't inplace modify module params/buffers inside HigherOrderOp
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_exc.py", line 20, in <module>
mod_for_compile(torch.ones(3, 4))
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 338, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 500, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 634, in _convert_frame
result = inner_convert(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 382, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 484, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 451, in transform
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2088, in run
super().run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 261, in call_function
return super().call_function(tx, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 601, in inline_user_function_return
result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2300, in inline_call_
tracer.run()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 728, in run
and self.step()
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 691, in step
getattr(self, inst.opname)(inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 392, in wrapper
return inner_fn(self, inst)
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1119, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 565, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/yidi/local/pytorch/torch/_dynamo/variables/higher_order_ops.py", line 53, in graph_break_as_hard_error
raise UncapturedHigherOrderOpError(reason + msg) from e
torch._dynamo.exc.UncapturedHigherOrderOpError: Cond doesn't work unless it is captured completely with torch.compile. Scroll up to find out what causes the graph break.
from user code:
File "/home/yidi/local/pytorch/test_exc.py", line 16, in forward
return control_flow.cond(x.shape[0] > 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 127, in cond
return cond_op(pred, true_fn, false_fn, operands)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108817
Approved by: https://github.com/zou3519