Customize traceback for calls to symbolically-traced code (#51648)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51648
The following code will throw during the call to `traced(5)`:
```python
class M(nn.Module):
def __init__(self):
super(M, self).__init__()
self.W = torch.nn.Parameter(torch.randn(5))
def forward(self, x):
return torch.dot(self.W, x)
traced = fx.symbolic_trace(M())
traced(5)
```
Traceback before:
```
Traceback (most recent call last):
File "test/tinytest.py", line 26, in <module>
traced(5)
File "/home/ansley/local/pytorch/torch/fx/graph_module.py", line 338, in wrapped_call
return self._cls_call(self, *args, **kwargs)
File "/home/ansley/local/pytorch/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "<eval_with_key_0>", line 4, in forward
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int
```
Traceback after:
```
Traceback (most recent call last):
File "/home/ansley/local/pytorch/torch/fx/graph_module.py", line 338, in wrapped_call
return torch.nn.Module.__call__(self, *args, **kwargs)
File "/home/ansley/local/pytorch/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "<eval_with_key_1>", line 4, in forward
dot_1 = torch.dot(w, x); w = x = None
TypeError: dot(): argument 'tensor' (position 2) must be Tensor, not int
Call using an FX-traced Module, line 4 of the traced Module’s generated forward function:
w = self.W
dot_1 = torch.dot(w, x); w = x = None
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
relu_1 = dot_1.relu(); dot_1 = None
return relu_1
```
(Note that the same `TypeError` is thrown despite modifying the traceback.)
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D26424005
Pulled By: ansley
fbshipit-source-id: 368f46ba81fb3111bd09654825bb2ac5595207d1