[cond] Add more tests for valid inputs of cond (#110727)
This PR adds a parametrized test for cond. It tests cond can be traced with valid inputs. Specifically valid inputs is combination of:
- pred (python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- Operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)
What this test doesn't cover:
- pred: symbolic boolean expression as predicate
- true_fn/false_fn: that mutates indermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, python constant closures
- nested_level: 3+
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519