pytorch
cd813f16 - Add functional api for `nn.Module` (#61447)

Commit
3 years ago
Add functional api for `nn.Module` (#61447) Summary: Fixes https://github.com/pytorch/pytorch/issues/58839 After discussing with albanD he proposed this simple design. Let's iterate over the idea here :). Thanks. The main point that this PR does is to use reparametrization to be reverted at the end of the functional call. This allows us to have the original model with its status unchanged, also in this scenario the module is created without parameters so this will hard error if not all parameters are specified when the forward pass is done. ``` python import torch import torch.nn.utils._stateless class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) def forward(self, x): return self.l1(x) mod = MyModule() print('weight before', mod.l1.weight) x = torch.rand((1, 1)) parameters = {"l1.weight": torch.nn.Parameter(torch.tensor([[1.0]])), "l1.bias": torch.nn.Parameter(torch.tensor([0.0]))} res = torch.nn.utils._stateless.functional_call(mod, parameters, x) print('Functional call input ', x, ' and result ', res) print('weight after', mod.l1.weight) ``` Output ``` weight before Parameter containing: tensor([[-0.4419]], requires_grad=True) Functional call input tensor([[0.3531]]) and result tensor([[0.3531]], grad_fn=<AddmmBackward>) weight after Parameter containing: tensor([[-0.4419]], requires_grad=True) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/61447 Reviewed By: soulitzer Differential Revision: D31082765 Pulled By: albanD fbshipit-source-id: ba814d0f9162fb39c59989ca9a8efe160405ba76
Author
Emilio Castillo
Parents
Loading