Remove `with_traceback(None)` in `wrapped_call` so the traceback shows the root-cause error
Before:
```
Traceback (most recent call last):
File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
return self.executor.run(*executor_args)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
return super().run(*args, initial_env=initial_env)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
self.env[node] = self.run_node(node)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
return super().call_module(target, args, kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
return submod(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e.with_traceback(None)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
After:
```
Traceback (most recent call last):
File "/Users/pbelevich/PycharmProjects/PiPPy/test/t5_test.py", line 37, in <module>
t5_pipe_output = t5_pipe(input_ids=t5_input, decoder_attention_mask=None, decoder_input_ids=decoder_input_ids)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 251, in forward
return self.executor.run(*executor_args)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 155, in run
return super().run(*args, initial_env=initial_env)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 121, in run
self.env[node] = self.run_node(node)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 148, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/Users/pbelevich/PycharmProjects/PiPPy/pippy/IR.py", line 170, in call_module
return super().call_module(target, args, kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/interpreter.py", line 265, in call_module
return submod(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 620, in wrapped_call
return cls_call(self, *args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 630, in wrapped_call
raise e
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/fx/graph_module.py", line 622, in wrapped_call
return super(cls, self).__call__(*args, **kwargs)
File "/Users/pbelevich/miniconda3/envs/PiPPy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "<eval_with_key>.42", line 74, in forward
File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/utils/fx.py", line 180, in wrapper
return func(*args, **kwargs)
File "/Users/pbelevich/PycharmProjects/pbelevich-transformers/src/transformers/modeling_utils.py", line 256, in create_extended_attention_mask_for_decoder
causal_mask = causal_mask.to(attention_mask.dtype)
AttributeError: 'NoneType' object has no attribute 'dtype'
```
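With the traceback no longer stripped, the last lines of the stack trace show where the problem actually occurs (here, `create_extended_attention_mask_for_decoder` in `modeling_utils.py`, where `attention_mask` is `None`).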
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74655
Approved by: https://github.com/ansley, https://github.com/rohan-varma