Fix split module interaction with dead code (#104554)
Summary:
This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs.
This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function
Unit Test Added:
test_split_module_dead_code
A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
def forward(self, x):
output = x * 2 # we want this
dead_line = x + 2 # this is dead
return output
```
Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```
After:
```
class GraphModule(torch.nn.Module):
def forward(self, x):
# No stacktrace found for following nodes
submod_2 = self.submod_2(x)
submod_1 = self.submod_1(x); x = None
return submod_1
class GraphModule(torch.nn.Module):
def forward(self, x):
# No stacktrace found for following nodes
add = x + 2; x = None
return None
class GraphModule(torch.nn.Module):
def forward(self, x):
# No stacktrace found for following nodes
mul = x * 2; x = None
return mul
```
Submod 2 is correctly extracted
Test Plan: Tested with new unit test
Differential Revision: D47196732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104554
Approved by: https://github.com/yf225