Fix perfornance issue of GroupNorm on CUDA when feature map is small. (#46170)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46170
Fix perfornance issue of GroupNorm on CUDA when feature map is small.
Benchmark script:
```
import torch
import torch.nn.functional as F
from timeit import Timer
norm = torch.nn.GroupNorm(8, 512).cuda()
num = 5000
sizes = [(1024, 512, 14, 14), (1024, 512, 7, 7), (1024, 512)]
def forward(x):
_ = norm(x)
torch.cuda.synchronize()
def backward(y, grad):
y.backward(grad, retain_graph=True)
torch.cuda.synchronize()
if __name__ == "__main__":
# warm up
x = torch.rand(*(sizes[0]), dtype=torch.float,
device="cuda", requires_grad=True)
for _ in range(100):
forward(x)
for size in sizes:
x = torch.rand(*size, dtype=torch.float,
device="cuda", requires_grad=True)
t = Timer("forward(x)", "from __main__ import forward, x")
print(f"size = {size}:")
t1 = t.timeit(num) / num * 1e6
print(f"avg_forward_time = {t1}us")
y = norm(x)
grad = torch.randn_like(y)
t = Timer("backward(y, grad)", "from __main__ import backward, y, grad")
t2 = t.timeit(num) / num * 1e6
print(f"avg_backward_time = {t2}us")
```
Benchmark result before this Diff:
```
size = (1024, 512, 14, 14):
avg_forward_time = 1636.729855206795us
avg_backward_time = 5488.682465581223us
size = (1024, 512, 7, 7):
avg_forward_time = 465.88476160541177us
avg_backward_time = 3129.9425506033003us
size = (1024, 512):
avg_forward_time = 96.90486900508404us
avg_backward_time = 2319.4099438143894us
```
Benchmark result after this Diff:
```
size = (1024, 512, 14, 14):
avg_forward_time = 1635.6191572034732us
avg_backward_time = 4140.7730475999415us
size = (1024, 512, 7, 7):
avg_forward_time = 463.6513736099005us
avg_backward_time = 1641.7451039887965us
size = (1024, 512):
avg_forward_time = 66.59087920561433us
avg_backward_time = 128.6882139975205us
```
Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"
Reviewed By: hl475, houseroad
Differential Revision: D24242738
fbshipit-source-id: b52c82d7b6e47855c48fa8ceacd0c55d03bb92d5