[FX] get the correct error message (#47108)
Summary:
Currently, code like
```
class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.W = torch.nn.Parameter(torch.randn(5))
def forward(self, x):
return torch.dot(self.W, x)
mod = Test()
print(fx.symbolic_trace(Test())(5))
```
gives an error like the below, which does not show the actual code that throws the error.
```
Traceback (most recent call last):
File "t.py", line 20, in <module>
print(fx.symbolic_trace(Test())(5))
File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 191, in debug_forward
return src_forward(self, *args, **kwargs)
File "<eval_with_key_0>", line 5, in forward
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int
```
This is particularly annoying when your function has already been transformed several times.
So, the really annoying thing is that the error clearly has the requisite information in `exception.__traceback__` - it just isn't printing it.
I think the right way of doing this is simply replacing `sys.excepthook`. This appears to be the standard way to modify exception messages.
**Scratch the below**
The 2 methods in the PR right now are:
1. Just prepend the final part of the traceback to the beginning of your error message. Looks like
```
Traceback (most recent call last):
File "t.py", line 20, in <module>
print(fx.symbolic_trace(Test())(5))
File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward
raise e
File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward
return src_forward(self, *args, **kwargs)
File "<eval_with_key_0>", line 5, in forward
TypeError: File "<eval_with_key_0>", line 5, in forward
dot_1 = torch.dot(w, x)
dot(): argument 'tensor' (position 2) must be Tensor, not int
```
2. Use the `from exception` feature in Python. Looks like
```
Traceback (most recent call last):
File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 192, in debug_forward
return src_forward(self, *args, **kwargs)
File "<eval_with_key_0>", line 5, in forward
TypeError: File "<eval_with_key_0>", line 5, in forward
dot_1 = torch.dot(w, x)
dot(): argument 'tensor' (position 2) must be Tensor, not int
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "t.py", line 20, in <module>
print(fx.symbolic_trace(Test())(5))
File "/home/chilli/fb/pytorch/torch/nn/modules/module.py", line 744, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/chilli/fb/pytorch/torch/fx/graph_module.py", line 197, in debug_forward
raise Exception(last_tb) from e
Exception: File "<eval_with_key_0>", line 5, in forward
dot_1 = torch.dot(w, x)
```
I think the first one looks better, but it's pretty hacky since we're shoving the traceback in the message.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47108
Reviewed By: jamesr66a
Differential Revision: D24751019
Pulled By: Chillee
fbshipit-source-id: 83e6ed0165f98632a77c73de75504fd6263fff40