[pytorch] correct input size check for GroupNorm (#33008)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33008
Corrects D19373507 to allow valid use cases that fail now. Multiplies batch size by the number of elements in a group to get the correct number of elements over which statistics are computed.
**Details**:
The current implementation disallows GroupNorm to be applied to tensors of shape e.g. `(1, C, 1, 1)` to prevent cases where statistics are computed over 1 element and thus result in a tensor filled with zeros.
However, in GroupNorm the statistics are calculated across channels. So in case where one has an input tensor of shape `(1, 256, 1, 1)` for `GroupNorm(32, 256)`, the statistics will be computed over 8 elements and thus be meaningful.
One use case is [Atrous Spatial Pyramid Pooling (ASPPPooling)](https://github.com/pytorch/vision/blob/791c172a337d98012018f98ffde93b1020ba3ed5/torchvision/models/segmentation/deeplabv3.py#L50), where GroupNorm could be used in place of BatchNorm [here](https://github.com/pytorch/vision/blob/791c172a337d98012018f98ffde93b1020ba3ed5/torchvision/models/segmentation/deeplabv3.py#L55). However, now this is prohibited and results in failures.
Proposed solution consists in correcting the computation of the number of elements over which statistics are computed. The number of elements per group is taken into account in the batch size.
Test Plan: check that existing tests pass
Reviewed By: fmassa
Differential Revision: D19723407
fbshipit-source-id: c85c244c832e6592e9aedb279d0acc867eef8f0c