pytorch
37324673 - [FX] get the correct error message (#47108)

Commit
4 years ago
[FX] get the correct error message (#47108) Summary: Currently, code like ``` class Test(nn.Module): def __init__(self): super(Test, self).__init__() self.W = torch.nn.Parameter(torch.randn(5)) def forward(self, x): return torch.dot(self.W, x) mod = Test() print(fx.symbolic_trace(Test())(5)) ``` gives an error like the below, which does not show the actual code that throws the error. ``` Traceback (most recent call last): File "t.py", line 20, in <module> print(fx.symbolic_trace(Test())(5)) File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl result = self.forward(*input, **kwargs) File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 191, in debug_forward return src_forward(self, *args, **kwargs) File "<eval_with_key_0>", line 5, in forward TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int ``` This is particularly annoying when your function has already been transformed several times. So, the really annoying thing is that the error clearly has the requisite information in `exception.__traceback__` - it just isn't printing it. I think the right way of doing this is simply replacing `sys.excepthook`. This appears to be the standard way to modify exception messages. **Scratch the below** The 2 methods in the PR right now are: 1. Just prepend the final part of the traceback to the beginning of your error message. Looks like ``` Traceback (most recent call last): File "t.py", line 20, in <module> print(fx.symbolic_trace(Test())(5)) File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl result = self.forward(*input, **kwargs) File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward raise e File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward return src_forward(self, *args, **kwargs) File "<eval_with_key_0>", line 5, in forward TypeError: File "<eval_with_key_0>", line 5, in forward dot_1 = torch.dot(w, x) dot(): argument 'tensor' (position 2) must be Tensor, not int ``` 2. Use the `from exception` feature in Python. Looks like ``` Traceback (most recent call last): File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward return src_forward(self, *args, **kwargs) File "<eval_with_key_0>", line 5, in forward TypeError: File "<eval_with_key_0>", line 5, in forward dot_1 = torch.dot(w, x) dot(): argument 'tensor' (position 2) must be Tensor, not int The above exception was the direct cause of the following exception: Traceback (most recent call last): File "t.py", line 20, in <module> print(fx.symbolic_trace(Test())(5)) File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl result = self.forward(*input, **kwargs) File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward raise Exception(last_tb) from e Exception: File "<eval_with_key_0>", line 5, in forward dot_1 = torch.dot(w, x) ``` I think the first one looks better, but it's pretty hacky since we're shoving the traceback in the message. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47108 Reviewed By: jamesr66a Differential Revision: D24751019 Pulled By: Chillee fbshipit-source-id: 83e6ed0165f98632a77c73de75504fd6263fff40
Author
Parents
Loading