[fx] Added prototype of `to_folder` (#47544)
Summary:
Given an FX module `foo` (a `GraphModule`), you can call `foo.to_folder('foo_folder', 'Foo')` to dump the current FX module into runnable Python code.
That is
```python
foo = <fx GraphModule>
foo.to_folder('bar', 'Foo')
from bar import Foo
foo2 = Foo()
# for all x, foo2(x) == foo(x)
```
This has several use cases, largely lifted from jamesr66a's doc here: https://fb.quip.com/U6KHAFaP2cWa (FB-internal).
1. As we apply more heavyweight function transformations with FX, figuring out what's going on can be quite a difficult experience. In particular, the tools typically used for debugging (like `print` or `import pdb; pdb.set_trace()`) no longer work. This is particularly painful if you're using an FX transform like `grad` or `vmap`. With this, you simply open up the dumped file and add `print`/`pdb` statements wherever you'd like.
2. This also provides an immense amount of user control. Some potential use-cases:
- Let's say an existing FX transform has some bug, or generates suboptimal code. Instead of modifying that FX transform, writing another FX pass that fixes the suboptimal code, or simply giving up on FX, users can work around it by modifying the resulting code themselves.
- This allows users to check in their FX modules into source control.
- You could even imagine using this as part of a code-gen-type workflow, where you write a function, `vmap` it to get the function you actually want, and then simply copy the output of the `vmap`ed function without needing FX at all in the final code.
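The core mechanism behind these use cases can be sketched in plain Python (this is a hand-rolled illustration, not the PR's implementation; `dump_to_folder`, the `module.py` filename, and the `Adder` class are all made-up stand-ins): generated code is written to a folder as a real source file, which can then be imported, edited, and debugged like any hand-written module.

```python
import importlib.util
from pathlib import Path

def dump_to_folder(folder: str, class_name: str, template: str) -> None:
    """Write generated source into <folder>/module.py, mimicking the spirit of to_folder."""
    path = Path(folder)
    path.mkdir(exist_ok=True)
    (path / "module.py").write_text(template.format(name=class_name))

# Hypothetical "generated" code -- in FX this would come from the traced graph.
source = (
    "class {name}:\n"
    "    def forward(self, x):\n"
    "        add_1 = x + 1\n"
    "        return add_1\n"
)

dump_to_folder("gen_folder", "Adder", source)

# Import the dumped file like any ordinary Python module; at this point a user
# could freely insert print()/pdb calls into gen_folder/module.py.
spec = importlib.util.spec_from_file_location("gen_module", "gen_folder/module.py")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

print(mod.Adder().forward(41))  # -> 42
```

Once the code lives on disk as a regular file, all the usual tooling (source control, editors, debuggers) applies to it with no FX dependency.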
An example:
```python
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(2))
        self.linear = nn.Linear(2, 2)
        self.attr = torch.randn(2)
        self.attr2 = torch.randn(2)

    def forward(self, x):
        return self.linear(self.W + (self.attr + self.attr2) + x)

mod = fx.symbolic_trace(Test())
mod.to_folder('foo', 'Foo')
```
results in
```python
import torch

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        state_dict = torch.load('foo/state_dict.pt')
        self.linear = torch.load('foo/linear.pt')  # Linear(in_features=2, out_features=2, bias=True)
        self.__tensor_constant0 = state_dict['__tensor_constant0']
        self.W = torch.nn.Parameter(state_dict['W'])

    def forward(self, x):
        w = self.W
        tensor_constant0 = self.__tensor_constant0
        add_1 = w + tensor_constant0
        add_2 = add_1 + x
        linear_1 = self.linear(add_2)
        return linear_1
```
Some current issues:
1. How do you actually ... save things like modules or parameters? I don't think FX is in the business of tracking initializations and such. Thus, the only way I see to do it is to dump the parameters/modules as blobs, and then load them in the generated initialization. This is a somewhat subpar user experience, and perhaps prevents this from being used in some cases (e.g., you would need to check the blobs into source control to save the model).
2. Currently, the only "atomic" modules we have are those in `torch.nn`. However, if we want to allow flexibility here, and for example allow user-defined "atomic" modules, it's not clear how to dump those in a way that can then be loaded elsewhere.
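The blob approach described in issue 1 can be sketched generically (a toy illustration only; `save_blobs`/`load_blobs` and the `state_dict.pkl` filename are made up, and plain `pickle` stands in for `torch.save`/`torch.load`): parameters are serialized to files at dump time, and the generated `__init__` reads them back.

```python
import pickle
from pathlib import Path

def save_blobs(folder: str, tensors: dict) -> None:
    # At dump time, serialize all parameters/constants as one blob.
    path = Path(folder)
    path.mkdir(exist_ok=True)
    with open(path / "state_dict.pkl", "wb") as f:
        pickle.dump(tensors, f)

def load_blobs(folder: str) -> dict:
    # The generated __init__ would do the equivalent of this load.
    with open(Path(folder) / "state_dict.pkl", "rb") as f:
        return pickle.load(f)

# Plain lists stand in for tensors here.
save_blobs("blob_folder", {"W": [0.5, -1.0], "__tensor_constant0": [1.0, 2.0]})
state = load_blobs("blob_folder")
print(state["W"])  # -> [0.5, -1.0]
```

This also makes the source-control concern concrete: the generated `.py` file is human-readable and diffable, but the blob files are opaque binaries that would need to be checked in alongside it.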
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47544
Reviewed By: jamesr66a
Differential Revision: D25232917
Pulled By: Chillee
fbshipit-source-id: fd2b61a5f40e614fc94256a2957ed1d57fcf5492