[export] add sequential_split to prepare replacing set_grad_enabled with hop (#119732)
This PR is 1/N in a series that transforms global-state-mutating ops, such as `torch._C._set_grad_enabled` calls, in the pre-dispatch graph into a higher-order op so that the graph becomes more functional. We make use of `split_module` to perform the transformation.
This PR preserves the node names of the original module by adding a new kwarg `keep_original_node_name` to `split_module`.
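For context, here is a minimal sketch of how such a split can be expressed with `split_module` and the new kwarg. The helper name `sequential_split_sketch` and the callback are illustrative only, not the exact code added by this PR:
```python
# Minimal sketch (not the exact implementation in this PR): cut the graph into
# consecutive submodules at every torch._C._set_grad_enabled call, and keep the
# original node names via the new keep_original_node_name kwarg.
import torch
from torch.fx.passes.split_module import split_module


def sequential_split_sketch(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    partition_id = 0

    def split_callback(node: torch.fx.Node) -> int:
        # Each _set_grad_enabled call starts a new partition, so every
        # submodule runs entirely under a single grad mode.
        nonlocal partition_id
        if (
            node.op == "call_function"
            and node.target is torch._C._set_grad_enabled
        ):
            partition_id += 1
        return partition_id

    return split_module(gm, gm, split_callback, keep_original_node_name=True)
```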
For a graph that looks like this:
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
```
Before the change, `split_module` returns the following graph and submodules (notice the renaming from `add` -> `add_tensor` and `sin` -> `sin_default`):
```python
def forward(self, arg_0):
    arg0_1, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(arg0_1); arg0_1 = None
    submod_1 = self.submod_1(submod_0); submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add_tensor = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin_default = torch.ops.aten.sin.default(add_tensor); add_tensor = None
    sum_default = torch.ops.aten.sum.default(sin_default); sin_default = None
    return sum_default

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_tensor = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return add_tensor

# submod_2
def forward(self, add_1):
    _set_grad_enabled = torch._C._set_grad_enabled(True)
    sub_tensor = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return sub_tensor
""")
```
After the change, the test produces the following graphs; all the node names in the original graph module are preserved in the submodules:
```python
def forward(self, arg_0):
    sub, = fx_pytree.tree_flatten_spec(([arg_0], {}), self._in_spec)
    submod_0 = self.submod_0(sub); sub = None
    submod_1 = self.submod_1(submod_0); submod_0 = None
    submod_2 = self.submod_2(submod_1)
    return pytree.tree_unflatten((submod_1, submod_2), self._out_spec)

# submod_0
def forward(self, arg0_1):
    add = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return sum_1

# submod_1
def forward(self, sum_1):
    _set_grad_enabled = torch._C._set_grad_enabled(False)
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return add_1

# submod_2
def forward(self, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return sub
```
Note that currently we call `split_module` on the graph after pre-dispatch AOT. The difference is even larger if we run `split_module` on the graph module produced by dynamo, where all the original variable names from the user program are preserved after dynamo but lost after `split_module` without this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119732
Approved by: https://github.com/tugsbayasgalan