pytorch
aa4ed332 - Improve torch.cond useability: Return UserError with actionable error messages (#98909)

Commit
1 year ago
Improve torch.cond useability: Return UserError with actionable error messages (#98909) It's part of the effort to improve PT2 Export UX. This PR is to improve the usability of `torch.cond()` by separating user errors from the dynamo internal errors. By definition, user error means the usage of `torch.cond()` violates the restrictions of this API therefore needs users to take action and fix the error. In this notebook N3363227 we discovered a bunch of limitations of using `torch.cond(pred, true_fn, false_fn, operands)`. In summary, the limitations can be categorized as: - predicate restriction (`pred`) - operands restriction (`operands`) - branch restriction (`true_fn` & `false_fn`) The error message will be more accurate about where the (user) error is from and more actionable for users to fix it. For example, `operands` must be a list of tensors and the signature of `true_fn` and `false_fn` must match with the `operands`. If the operands contains non-tensor types, user will see error message like: ``` torch._dynamo.exc.UserError: Expected a list of tensors but got ["<class 'torch.Tensor'>", "<class 'float'>"] from user code: File "~/pytorch/test/dynamo/test_export.py", line 2504, in f_non_tensor_operands return cond(True, lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]) ``` If the signature of the branch function doesn't match with `operands`, user will see error message like: ``` torch._dynamo.exc.UserError: too many positional arguments. func = 'false_fn' ~/pytorch/test/dynamo/test_export.py:2514, args = [<class 'torch.Tensor'>, <class 'torch.Tensor'>], kwargs = {} ``` Or if the tensor returned from user defined branches has different metadata, e.g. shapes, dtypes, etc., user will see error message like: ``` TypeError: Expected each tensor to have same metadata but got: cond_true_0 returns TensorMetadata(shape=torch.Size([2, 1]), dtype=torch.int64, requires_grad=False, stride=(1, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={}) cond_false_0 returns TensorMetadata(shape=torch.Size([1]), dtype=torch.float32, requires_grad=False, stride=(1,), memory_format=torch.contiguous_format, is_quantized=False, qparams={}) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/98909 Approved by: https://github.com/jansel
Author
Committer
Parents
Loading