Dirac init compatibility with group convolutions (#32825)
Summary:
Initializing the weights of a grouped convolution with init.dirac_ and then applying the layer previously produced an incorrect output: only the first channel was passed through and the remaining channels were zeroed:
```
import torch

x = torch.randn([1, 3, 3, 3])
print('input:\n', x)
conv_layer = torch.nn.Conv2d(3, 3, 3, padding=1, groups=3, bias=False)
torch.nn.init.dirac_(conv_layer.weight.data)
print('\noutput (before this PR):\n', conv_layer(x))
input:
tensor([[[[ 0.5369, -1.1428, 0.1031],
[ 0.4638, -0.0854, -0.6553],
[ 0.8321, -2.5926, -0.3214]],
[[-0.2289, -0.0895, 0.4407],
[ 1.2309, -1.2096, -1.5216],
[-0.1798, 1.1694, 0.3469]],
[[ 0.1905, 0.8095, 0.5490],
[-0.4525, -0.4284, -0.1141],
[ 1.1857, -0.9246, -0.5119]]]])
output (before this PR):
tensor([[[[ 0.5369, -1.1428, 0.1031],
[ 0.4638, -0.0854, -0.6553],
[ 0.8321, -2.5926, -0.3214]],
[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)
```
This PR adds a groups argument to init.dirac_ so that the initialization can match the grouping of the layer:
```
torch.nn.init.dirac_(conv_layer.weight.data, groups=3)
print('output (after this PR):\n', conv_layer(x))
output (after this PR):
tensor([[[[ 0.5369, -1.1428, 0.1031],
[ 0.4638, -0.0854, -0.6553],
[ 0.8321, -2.5926, -0.3214]],
[[-0.2289, -0.0895, 0.4407],
[ 1.2309, -1.2096, -1.5216],
[-0.1798, 1.1694, 0.3469]],
[[ 0.1905, 0.8095, 0.5490],
[-0.4525, -0.4284, -0.1141],
[ 1.1857, -0.9246, -0.5119]]]], grad_fn=<MkldnnConvolutionBackward>)
```
When out_channels differs from in_channels, it does the natural thing: it applies the identity within each group separately, and any extra output channels in a group are left at zero:
```
x = torch.randn([1, 2, 3, 3])
print('input:\n', x)
conv_layer = torch.nn.Conv2d(2, 4, 3, padding=1, groups=2, bias=False)
torch.nn.init.dirac_(conv_layer.weight.data, groups=2)
print('\noutput:\n', conv_layer(x))
input:
tensor([[[[ 1.2205, -0.6608, 0.8640],
[-0.5464, 1.1288, 1.4726],
[-0.6693, 0.4000, -1.7613]],
[[-0.8760, -0.8814, -0.4705],
[ 0.6283, -0.5943, 0.6873],
[-0.6852, 1.4723, 0.3325]]]])
output:
tensor([[[[ 1.2205, -0.6608, 0.8640],
[-0.5464, 1.1288, 1.4726],
[-0.6693, 0.4000, -1.7613]],
[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]],
[[-0.8760, -0.8814, -0.4705],
[ 0.6283, -0.5943, 0.6873],
[-0.6852, 1.4723, 0.3325]],
[[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]]]], grad_fn=<MkldnnConvolutionBackward>)
```
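The per-group placement can be sketched without PyTorch. The helper below is hypothetical (dirac_centers is not part of torch.nn.init); it only lists which (out_channel, in_channel) pairs get a 1 at the kernel centre, assuming the weight has shape (out_channels, in_channels / groups, ...) and that each group receives an identity over the smaller of its two channel counts:

```python
def dirac_centers(out_channels, in_channels_per_group, groups=1):
    """Hypothetical helper: list the (out_channel, in_channel) weight
    indices whose kernel centre would be set to 1."""
    if out_channels % groups != 0:
        raise ValueError("out_channels must be divisible by groups")
    out_per_group = out_channels // groups
    # Each group gets an identity over min(out, in) channels of that group.
    min_dim = min(out_per_group, in_channels_per_group)
    return [(g * out_per_group + d, d)
            for g in range(groups)
            for d in range(min_dim)]


# Conv2d(3, 3, 3, groups=3) has weight shape (3, 1, 3, 3).
print(dirac_centers(3, 1, groups=1))  # [(0, 0)]
print(dirac_centers(3, 1, groups=3))  # [(0, 0), (1, 0), (2, 0)]
# Conv2d(2, 4, 3, groups=2) has weight shape (4, 1, 3, 3).
print(dirac_centers(4, 1, groups=2))  # [(0, 0), (2, 0)]
```

With groups=1 on the 3-channel layer only output channel 0 is wired to an input, which matches the zeroed channels in the first example; with groups=2 on the 2-to-4 layer, output channels 0 and 2 carry the two inputs and channels 1 and 3 stay zero.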
The 'groups' argument defaults to 1, so the change is backward compatible.
Tests are modified to cover groups>1 while still including groups=1 cases.
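As an illustration only (this is not the test code from the PR), the two properties above can be checked roughly as follows. The sketch assumes a PyTorch build that already includes the groups argument; the tensor shapes are arbitrary choices:

```python
import torch
import torch.nn.functional as F

# Backward compatibility: the default call and an explicit groups=1
# should fill the weight identically.
w_default = torch.empty(4, 2, 3, 3)
w_explicit = torch.empty(4, 2, 3, 3)
torch.nn.init.dirac_(w_default)
torch.nn.init.dirac_(w_explicit, groups=1)
same_as_default = torch.equal(w_default, w_explicit)

# groups > 1: a depthwise 3x3 convolution initialized with groups=3
# should pass its input through unchanged.
x = torch.randn(1, 3, 5, 5)
w_grouped = torch.empty(3, 1, 3, 3)
torch.nn.init.dirac_(w_grouped, groups=3)
y = F.conv2d(x, w_grouped, padding=1, groups=3)
is_identity = torch.allclose(y, x)

print(same_as_default, is_identity)
```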
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32825
Differential Revision: D19859926
Pulled By: vincentqb
fbshipit-source-id: 9dfdd24471ff14d79c442dfd28c1891aff812fdf