Fix cond branches take no arguments (#109308)
For code like this:
```python
import torch
from functorch.experimental import control_flow
def exportdb_example2(x):
def true_fn():
return torch.sin(x)
def false_fn():
return torch.cos(x)
return control_flow.cond(x.sum() > 0, true_fn, false_fn, [])
ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),))
```
before the pr, when the branches take an empty/list of tuple as inputs, we'll have error like following:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/test_cond.py", line 11, in <module>
ep = torch._export.export(exportdb_example2, (torch.randn(4, 5),))
File "/home/yidi/local/pytorch/torch/_export/__init__.py", line 340, in export
gm_torch_level, _ = torch._dynamo.export(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1207, in inner
result_traced = opt_f(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/test_cond.py", line 3, in exportdb_example2
def exportdb_example2(x):
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 1173, in result_capturing_wrapper
graph_captured_result = torch.func.functional_call(
File "/home/yidi/local/pytorch/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "/home/yidi/local/pytorch/torch/nn/utils/stateless.py", line 264, in _functional_call
return module(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__
raise e
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
File "<eval_with_key>.2", line 10, in forward
File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__
return wrapper()
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper
return self.dispatch(
File "/home/yidi/local/pytorch/torch/_ops.py", line 280, in dispatch
return kernel(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 52, in inner
return autograd_not_implemented_inner(op, deferred_error, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/utils.py", line 25, in autograd_not_implemented_inner
result = operator(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_ops.py", line 301, in __call__
return wrapper()
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_ops.py", line 297, in wrapper
return self.dispatch(
File "/home/yidi/local/pytorch/torch/_ops.py", line 255, in dispatch
return self.python_key_mode_table[type(curr_mode)](*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 310, in cond_fake_tensor_mode
flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 725, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 305, in __call__
raise e
File "/home/yidi/local/pytorch/torch/fx/graph_module.py", line 292, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/nn/modules/module.py", line 1528, in _call_impl
return forward_call(*args, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
```
Thanks for @williamwen42 spotting this error! We fix it by addressing the case when add_after is -1.
Test Plan:
See newly added tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109308
Approved by: https://github.com/williamwen42