Optimize min and max(reduce_dim) performance on CPU (#34875)
Summary:
This PR is about improve min and max(reduce_dim) performance on CPU.
Test script:
```
import torch
import torch.nn as nn
import time
torch.manual_seed(0)
def _time():
return time.time()
device = "cpu"
torch.set_num_threads(1)
#warm up
for n in [10, 200]:
# contiguous
# input = torch.randn(n, n, n, requires_grad=False, device=device)
# discontiguous
input = torch.randn(n, 2*n, n, requires_grad=False, device=device)[:, :n, :]
for dim in range(input.dim()):
for i in range(1000):
output = input.min(dim)
#output = input.max(dim)
for n in [10, 200]:
# contiguous input
# input = torch.randn(n, n, n, requires_grad=False, device=device)
# discontiguous
input = torch.randn(n, 2*n, n, requires_grad=False, device=device)[:, :n, :]
for dim in range(input.dim()):
fwd_t = 0
for i in range(10000):
t1 = _time()
output = input.min(dim)
#output = input.max(dim)
t2 = _time()
fwd_t = fwd_t + (t2 -t1)
fwd_avg = fwd_t / 10000 * 1000
print("size = (%d, %d, %d); reduce dim=%d; compute time is %.4f(ms)" % (n, n, n, dim, fwd_avg))
```
Test device: **skx-8180**.
### Contiguous case.
- num_threads = 56
| Bef(ms) | Bef(ms) | Bef(ms) | Aft(ms) | Aft(ms) | Aft(ms)
-- | -- | -- | -- | -- | -- | --
size | dim=0 | dim=1 | dim=2 | dim=0 | dim=1 | dim=2
n=10 | 0.0243 | 0.0243 | 0.0244 | 0.0063 | 0.0065 | 0.0063
n=200 | 0.9615 | 0.9453 | 0.7772 | 0.2937 | 0.2675 | 0.2607
- num_threads = 1
| Bef(ms) | Bef(ms) | Bef(ms) | Aft(ms) | Aft(ms) | Aft(ms)
-- | -- | -- | -- | -- | -- | --
size | dim=0 | dim=1 | dim=2 | dim=0 | dim=1 | dim=2
n=10 | 0.0126 | 0.0126 | 0.0114 | 0.0062 | 0.0065 | 0.0064
n=200 | 32.1276 | 33.3489 | 29.0757 | 8.0556 | 7.0188 | 6.5014
### Discontiguous case.
- num_threads = 56
| Bef(ms) | Bef(ms) | Bef(ms) | Aft(ms) | Aft(ms) | Aft(ms)
-- | -- | -- | -- | -- | -- | --
size | dim=0 | dim=1 | dim=2 | dim=0 | dim=1 | dim=2
n=10 | 0.0106 | 0.0115 | 0.0131 | 0.0063 | 0.0066 | 0.0065
n=200 | 14.652 | 15.3496 | 9.8153 | 0.2946 | 0.2708 | 0.267
- num_threads = 1
| Bef(ms) | Bef(ms) | Bef(ms) | Aft(ms) | Aft(ms) | Aft(ms)
-- | -- | -- | -- | -- | -- | --
size | dim=0 | dim=1 | dim=2 | dim=0 | dim=1 | dim=2
n=10 | 0.0108 | 0.0116 | 0.0132 | 0.0058 | 0.0062 | 0.0061
n=200 | 12.5132 | 13.0785 | 9.6738 | 8.3733 | 7.3051 | 6.4566
https://github.com/pytorch/pytorch/issues/24671 and https://github.com/pytorch/pytorch/issues/24672 are also fixed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34875
Differential Revision: D20605596
Pulled By: ngimel
fbshipit-source-id: 08fd4dacd1db63309123d7ec5942a4b8a0071896