Helper function for skipping module parameter / buffer initialization (#57555)
Summary:
This PR introduces a helper function named `torch.nn.utils.skip_init()` that accepts a module class object + `args` / `kwargs` and instantiates the module while skipping initialization of parameter / buffer values. See discussion at https://github.com/pytorch/pytorch/issues/29523 for more context. Example usage:
```python
import torch
m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
print(m.weight)
m2 = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1, device='cuda')
print(m2.weight)
m3 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=5, out_features=1)
print(m3.weight)
```
```
Parameter containing:
tensor([[-3.3011e+28, 4.5915e-41, -3.3009e+28, 4.5915e-41, 0.0000e+00]],
requires_grad=True)
Parameter containing:
tensor([[-2.5339e+27, 4.5915e-41, -2.5367e+27, 4.5915e-41, 0.0000e+00]],
device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
requires_grad=True)
```
Bikeshedding on the name / namespace is welcome, as well as comments on the design itself - just wanted to get something out there for discussion.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57555
Reviewed By: zou3519
Differential Revision: D28640613
Pulled By: jbschlosser
fbshipit-source-id: 5654f2e5af5530425ab7a9e357b6ba0d807e967f