pytorch
4a0f6e6c - report an error if num_channels is not divisible by num_groups for nn.GroupNorm

Commit
3 years ago
report an error if num_channels is not divisible by num_groups for nn.GroupNorm For a GroupNorm module, if num_channels is not divisible by num_groups, we need to report an error when defining a module other than at the running step. example: ``` import torch m = torch.nn.GroupNorm(5, 6) x = torch.randn(1, 6, 4, 4) y = m(x) ``` before: ``` Traceback (most recent call last): File "group_norm_test.py", line 8, in <module> y = m(x) File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1111, in _call_impl return forward_call(*input, **kwargs) File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 271, in forward input, self.num_groups, self.weight, self.bias, self.eps) File "/home/xiaobinz/miniconda3/envs/pytorch_mater/lib/python3.7/site-packages/torch/nn/functional.py", line 2500, in group_norm return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) RuntimeError: Expected number of channels in input to be divisible by num_groups, but got input of shape [1, 6, 4, 4] and num_groups=5 ``` after: ``` Traceback (most recent call last): File "group_norm_test.py", line 6, in <module> m = torch.nn.GroupNorm(5, 6) File "/home/xiaobinz/miniconda3/envs/pytorch_test/lib/python3.7/site-packages/torch/nn/modules/normalization.py", line 251, in __init__ raise ValueError('num_channels must be divisible by num_groups') ``` This PR also update the doc of num_groups. Pull Request resolved: https://github.com/pytorch/pytorch/pull/74293 Approved by: https://github.com/jbschlosser
Author
Committer
Parents
Loading