Add functional api for `nn.Module` (#61447)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/58839
After discussing with albanD he proposed this simple design.
Let's iterate over the idea here :).
Thanks.
The main point that this PR does is to use reparametrization to be reverted at the end of the functional call.
This allows us to have the original model with its status unchanged, also in this scenario the module is created without parameters so this will hard error if not all parameters are specified when the forward pass is done.
``` python
import torch
import torch.nn.utils._stateless
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
def forward(self, x):
return self.l1(x)
mod = MyModule()
print('weight before', mod.l1.weight)
x = torch.rand((1, 1))
parameters = {"l1.weight": torch.nn.Parameter(torch.tensor([[1.0]])),
"l1.bias": torch.nn.Parameter(torch.tensor([0.0]))}
res = torch.nn.utils._stateless.functional_call(mod, parameters, x)
print('Functional call input ', x, ' and result ', res)
print('weight after', mod.l1.weight)
```
Output
```
weight before Parameter containing:
tensor([[-0.4419]], requires_grad=True)
Functional call input tensor([[0.3531]]) and result tensor([[0.3531]], grad_fn=<AddmmBackward>)
weight after Parameter containing:
tensor([[-0.4419]], requires_grad=True)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61447
Reviewed By: soulitzer
Differential Revision: D31082765
Pulled By: albanD
fbshipit-source-id: ba814d0f9162fb39c59989ca9a8efe160405ba76