pytorch
3062e267 - [cond] Add more tests for valid inputs of cond (#110727)

Commit
1 year ago
[cond] Add more tests for valid inputs of cond (#110727) This PR adds a parametrized test for cond. It tests cond can be traced with valid inputs. Specifically valid inputs is combination of: - pred (python boolean, boolean tensor, int tensor, scalar tensor) - true_fn/false_fn (func, obj, nn_module) - Operands (0 or more tensor inputs), tested with 0 and 2 - closures (0 or more tensor closures), tested with 0 and 2 - nested_level (no nesting or level-2 nested cond) What this test doesn't cover: - pred: symbolic boolean expression as predicate - true_fn/false_fn: that mutates indermediate tensors - operands: non-tensor operands such as float, int - closures: nn_module attribute closures, python constant closures - nested_level: 3+ Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727 Approved by: https://github.com/zou3519
Author
Committer
Parents
Loading