Simplify convolution double backward gradInput formulas (#54840)
Summary:
Currently, in convolution double backward, the grad of input is computed as `convT(ggW, gO.T)`. Note that the first argument has the size of the convolution weight and the second the size of gradOutput, which is the reverse of how convolutions are regularly called, and those sizes are far from what the cudnn heuristics are trained for and what cudnn is guaranteed to have efficient kernels for. This takes cudnn 8 to some dark places, calling kernels that take 20-100 s. Luckily for us, convT is commutative (unlike conv), so `convT(ggW, gO)` is the same as `convT(gO, ggW)` modulo some transposes caused by the conventions around weight sizes, and we can use `convT(gO, ggW)` instead. As an added bonus, this formulation does not need a special branch for groups.
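To illustrate the commutativity the new formulation relies on, here is a minimal sketch (not part of the PR) restricted to batch size 1, a single channel, stride 1 and no padding, so that channel-layout conventions and the extra transposes don't obscure the point: a transposed convolution is a full convolution, which is symmetric in its two spatial arguments.
```
import torch
import torch.nn.functional as F

# a plays the role of gO, b plays the role of ggW (shapes are arbitrary examples)
a = torch.randn(1, 1, 8, 8)
b = torch.randn(1, 1, 3, 3)

# "input" of gradOutput size, "weight" of weight size (new ordering)
out1 = F.conv_transpose2d(a, b)
# arguments swapped (old ordering)
out2 = F.conv_transpose2d(b, a)

print(torch.allclose(out1, out2, atol=1e-6))  # True: full convolution is commutative
```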
For the following pretty standard convolution:
- cudnn 7.6 + old formulation takes 7.5 ms for double backward,
- cudnn 8 + old formulation takes ~40 s,
- cudnn 8 + new formulation takes 1.8 ms with benchmark enabled,
- cudnn 8 + new formulation takes 4 ms with benchmark disabled.

The benchmarking script is below:
```
import torch
import time

# torch.backends.cudnn.benchmark = True

def ggI(conv, inp):
    out = conv(inp)
    # First backward: grad of the output w.r.t. the weight, keeping the graph
    # so we can differentiate through it again.
    grads = torch.autograd.grad(out, conv.weight, torch.rand_like(out),
                                create_graph=True, retain_graph=True)
    torch.cuda.synchronize()
    start = time.time()
    # Double backward: this exercises the gradInput formula changed here.
    grads[0].backward(torch.rand_like(grads[0]))
    torch.cuda.synchronize()
    print("db time: ", time.time() - start)
    return inp.grad

conv = torch.nn.Conv2d(512, 256, kernel_size=3, padding=1, groups=2).cuda()
inp = torch.randn(1, 512, 128, 128, device="cuda", requires_grad=True)
for _ in range(20):
    ggI(conv, inp)
torch.cuda.synchronize()
```
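As a side note, the consistency of the convolution double backward formulas (including the groups path) can be sanity-checked numerically with gradgradcheck; a small double-precision sketch (not from the PR) along these lines:
```
import torch
import torch.nn.functional as F

# Check that convolution double backward agrees with numerical differentiation
# for a small grouped convolution (double precision, as gradgradcheck requires).
inp = torch.randn(2, 4, 5, 5, dtype=torch.double, requires_grad=True)
weight = torch.randn(6, 2, 3, 3, dtype=torch.double, requires_grad=True)

def fn(inp, weight):
    return F.conv2d(inp, weight, padding=1, groups=2)

print(torch.autograd.gradgradcheck(fn, (inp, weight)))  # True if consistent
```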
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54840
Reviewed By: mruberry
Differential Revision: D27384866
Pulled By: ngimel
fbshipit-source-id: c6c875776a9801a0a2cd2f34f8ec39d0fcd59df8
Author: Natalia Gimelshein