pytorch
c58709b7 - Helper function for skipping module parameter / buffer initialization (#57555)

Commit
3 years ago
Helper function for skipping module parameter / buffer initialization (#57555) Summary: This PR introduces a helper function named `torch.nn.utils.skip_init()` that accepts a module class object + `args` / `kwargs` and instantiates the module while skipping initialization of parameter / buffer values. See discussion at https://github.com/pytorch/pytorch/issues/29523 for more context. Example usage: ```python import torch m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1) print(m.weight) m2 = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device='cuda') print(m2.weight) m3 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=5, out_features=1) print(m3.weight) ``` ``` Parameter containing: tensor([[-3.3011e+28, 4.5915e-41, -3.3009e+28, 4.5915e-41, 0.0000e+00]], requires_grad=True) Parameter containing: tensor([[-2.5339e+27, 4.5915e-41, -2.5367e+27, 4.5915e-41, 0.0000e+00]], device='cuda:0', requires_grad=True) Parameter containing: tensor([[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]], requires_grad=True) ``` Bikeshedding on the name / namespace is welcome, as well as comments on the design itself - just wanted to get something out there for discussion. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57555 Reviewed By: zou3519 Differential Revision: D28640613 Pulled By: jbschlosser fbshipit-source-id: 5654f2e5af5530425ab7a9e357b6ba0d807e967f
Author
Parents
Loading