Skip None submodule during JIT-tracing (#49765)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49765
Some PyTorch module can have None as submodule, which causes the following error in JIT-tracing:
Repro script:
```
import torch
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = torch.nn.Linear(3, 4)
self.submod = None
def forward(self, inputs):
return inputs
m = TestModule()
tm = torch.jit.trace(m, torch.tensor(1.))
```
Error:
```
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 742, in trace
_module_class,
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 928, in trace_module
module = make_module(mod, _module_class, _compilation_unit)
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 560, in make_module
return _module_class(mod, _compilation_unit=_compilation_unit)
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 1039, in __init__
submodule, TracedModule, _compilation_unit=None
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 560, in make_module
return _module_class(mod, _compilation_unit=_compilation_unit)
File "/data/miniconda3/envs/master_nightly/lib/python3.7/site-packages/torch/jit/_trace.py", line 988, in __init__
assert isinstance(orig, torch.nn.Module)
AssertionError
```
This pull request changes the JIT-tracing logic to skip the None submodule when tracing.
Test Plan: `buck test mode/dev //caffe2/test:jit -- test_trace_skip_none_submodule`
Reviewed By: wanchaol
Differential Revision: D25670948
fbshipit-source-id: 468f42f5ddbb8fd3de06d0bc224dc67bd7172358