Move min and max (reduce all) to ATen (CPU) (#33936)
Summary:
This PR ports min and max (reduce all) to ATen.
Performance test script:
```
import torch
import timeit

torch.manual_seed(0)
# Uncomment to pin the benchmark to a single thread.
#torch.set_num_threads(1)
device = "cpu"
print(f'device: {device}')
for op in ('max', 'min'):
    for dtype in ('torch.double', 'torch.float', 'torch.int16', 'torch.int32', 'torch.int64'):
        # (numel, number of timed calls): total work per configuration stays constant
        for n, t in [(20_000, 200000),
                     (200_000, 20000)]:
            print(f'a.{op}(), numel() == {n} for {t} times, dtype={dtype}')
            # The printed value is the total wall time in seconds for all t calls.
            print(timeit.timeit(f'a.{op}()', setup=f'import torch; a = (torch.randn({n}) * 100).to({dtype})', number=t))
```
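Each number the script prints is the total time in seconds for all `t` calls, not a per-call latency. A minimal sketch of the conversion follows; the helper name `per_call_us` is hypothetical and the input values are illustrative, not measurements.

```python
def per_call_us(total_seconds: float, calls: int) -> float:
    """Convert a timeit total (seconds) into microseconds per call."""
    return total_seconds / calls * 1e6


# Illustrative values only: a 2.0 s total over 200000 calls.
latency = per_call_us(2.0, 200_000)
print(f"{latency:.2f} us per call")
```

Because the two configurations (20000 elements x 200000 calls, 200000 elements x 20000 calls) touch the same total number of elements, comparing their totals gives a rough view of per-call overhead versus per-element cost.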
Test device: **skx-8180, 2 sockets**
Before:
```
a.max(), numel() == 20000 for 200000 times, dtype=torch.double
2.773961597122252
a.max(), numel() == 200000 for 20000 times, dtype=torch.double
2.3256353894248605
a.max(), numel() == 20000 for 200000 times, dtype=torch.float
3.800648272037506
a.max(), numel() == 200000 for 20000 times, dtype=torch.float
3.31692426931113
a.max(), numel() == 20000 for 200000 times, dtype=torch.int16
2.735901520587504
a.max(), numel() == 200000 for 20000 times, dtype=torch.int16
2.2510280115529895
a.max(), numel() == 20000 for 200000 times, dtype=torch.int32
2.723656536079943
a.max(), numel() == 200000 for 20000 times, dtype=torch.int32
2.228839812800288
a.max(), numel() == 20000 for 200000 times, dtype=torch.int64
2.703160767443478
a.max(), numel() == 200000 for 20000 times, dtype=torch.int64
2.3175809988752007
a.min(), numel() == 20000 for 200000 times, dtype=torch.double
2.820106916129589
a.min(), numel() == 200000 for 20000 times, dtype=torch.double
2.325718787498772
a.min(), numel() == 20000 for 200000 times, dtype=torch.float
3.833602518774569
a.min(), numel() == 200000 for 20000 times, dtype=torch.float
3.316444822587073
a.min(), numel() == 20000 for 200000 times, dtype=torch.int16
2.7308286419138312
a.min(), numel() == 200000 for 20000 times, dtype=torch.int16
2.198460517451167
a.min(), numel() == 20000 for 200000 times, dtype=torch.int32
2.730219766497612
a.min(), numel() == 200000 for 20000 times, dtype=torch.int32
2.2268200274556875
a.min(), numel() == 20000 for 200000 times, dtype=torch.int64
2.7342184390872717
a.min(), numel() == 200000 for 20000 times, dtype=torch.int64
2.320415544323623
```
After:
```
a.max(), numel() == 20000 for 200000 times, dtype=torch.double
1.7767417253926396
a.max(), numel() == 200000 for 20000 times, dtype=torch.double
0.550495645031333
a.max(), numel() == 20000 for 200000 times, dtype=torch.float
1.1113408291712403
a.max(), numel() == 200000 for 20000 times, dtype=torch.float
0.44446005020290613
a.max(), numel() == 20000 for 200000 times, dtype=torch.int16
0.5246349424123764
a.max(), numel() == 200000 for 20000 times, dtype=torch.int16
0.47057845536619425
a.max(), numel() == 20000 for 200000 times, dtype=torch.int32
0.6597231412306428
a.max(), numel() == 200000 for 20000 times, dtype=torch.int32
0.40366593934595585
a.max(), numel() == 20000 for 200000 times, dtype=torch.int64
1.767227927222848
a.max(), numel() == 200000 for 20000 times, dtype=torch.int64
0.6187495030462742
a.min(), numel() == 20000 for 200000 times, dtype=torch.double
1.7881382443010807
a.min(), numel() == 200000 for 20000 times, dtype=torch.double
0.5440589748322964
a.min(), numel() == 20000 for 200000 times, dtype=torch.float
1.1090848250314593
a.min(), numel() == 200000 for 20000 times, dtype=torch.float
0.4293213738128543
a.min(), numel() == 20000 for 200000 times, dtype=torch.int16
0.5207074657082558
a.min(), numel() == 200000 for 20000 times, dtype=torch.int16
0.41422136034816504
a.min(), numel() == 20000 for 200000 times, dtype=torch.int32
0.6145811947062612
a.min(), numel() == 200000 for 20000 times, dtype=torch.int32
0.4172037309035659
a.min(), numel() == 20000 for 200000 times, dtype=torch.int64
1.7397673893719912
a.min(), numel() == 200000 for 20000 times, dtype=torch.int64
0.596766366623342
```
Single thread:
Before:
```
a.max(), numel() == 20000 for 200000 times, dtype=torch.double
2.5068740313872695
a.max(), numel() == 200000 for 20000 times, dtype=torch.double
2.234461876563728
a.max(), numel() == 20000 for 200000 times, dtype=torch.float
3.5549037409946322
a.max(), numel() == 200000 for 20000 times, dtype=torch.float
3.2497852174565196
a.max(), numel() == 20000 for 200000 times, dtype=torch.int16
2.493077039718628
a.max(), numel() == 200000 for 20000 times, dtype=torch.int16
2.171935741789639
a.max(), numel() == 20000 for 200000 times, dtype=torch.int32
2.469274105504155
a.max(), numel() == 200000 for 20000 times, dtype=torch.int32
2.273881389759481
a.max(), numel() == 20000 for 200000 times, dtype=torch.int64
2.5818942049518228
a.max(), numel() == 200000 for 20000 times, dtype=torch.int64
2.2394551979377866
a.min(), numel() == 20000 for 200000 times, dtype=torch.double
2.5894540259614587
a.min(), numel() == 200000 for 20000 times, dtype=torch.double
2.331936141476035
a.min(), numel() == 20000 for 200000 times, dtype=torch.float
3.590122046880424
a.min(), numel() == 200000 for 20000 times, dtype=torch.float
3.255849950015545
a.min(), numel() == 20000 for 200000 times, dtype=torch.int16
2.5205496419221163
a.min(), numel() == 200000 for 20000 times, dtype=torch.int16
2.168218174017966
a.min(), numel() == 20000 for 200000 times, dtype=torch.int32
2.658622432500124
a.min(), numel() == 200000 for 20000 times, dtype=torch.int32
2.3376982398331165
a.min(), numel() == 20000 for 200000 times, dtype=torch.int64
2.496626536361873
a.min(), numel() == 200000 for 20000 times, dtype=torch.int64
2.2504652086645365
```
After:
```
a.max(), numel() == 20000 for 200000 times, dtype=torch.double
1.9525171788409352
a.max(), numel() == 200000 for 20000 times, dtype=torch.double
1.6108122132718563
a.max(), numel() == 20000 for 200000 times, dtype=torch.float
1.2444602297618985
a.max(), numel() == 200000 for 20000 times, dtype=torch.float
0.7705567870289087
a.max(), numel() == 20000 for 200000 times, dtype=torch.int16
0.6575072864070535
a.max(), numel() == 200000 for 20000 times, dtype=torch.int16
0.13242999743670225
a.max(), numel() == 20000 for 200000 times, dtype=torch.int32
0.829406064003706
a.max(), numel() == 200000 for 20000 times, dtype=torch.int32
0.35575105529278517
a.max(), numel() == 20000 for 200000 times, dtype=torch.int64
1.6426756298169494
a.max(), numel() == 200000 for 20000 times, dtype=torch.int64
1.4049720335751772
a.min(), numel() == 20000 for 200000 times, dtype=torch.double
2.029639278538525
a.min(), numel() == 200000 for 20000 times, dtype=torch.double
1.6363644907251
a.min(), numel() == 20000 for 200000 times, dtype=torch.float
1.3821239182725549
a.min(), numel() == 200000 for 20000 times, dtype=torch.float
0.834847847931087
a.min(), numel() == 20000 for 200000 times, dtype=torch.int16
0.6913397628813982
a.min(), numel() == 200000 for 20000 times, dtype=torch.int16
0.1370067736133933
a.min(), numel() == 20000 for 200000 times, dtype=torch.int32
0.8190992185845971
a.min(), numel() == 200000 for 20000 times, dtype=torch.int32
0.3640836915001273
a.min(), numel() == 20000 for 200000 times, dtype=torch.int64
1.6516661625355482
a.min(), numel() == 200000 for 20000 times, dtype=torch.int64
1.4111155439168215
```
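To put the raw timings in perspective, the sketch below computes before/after speedups for a subset of the results: `a.max()` with `numel() == 200000`. The numbers are copied from the output above. The label "default threads" is an assumption for the first set of runs, based on `torch.set_num_threads(1)` being commented out in the script.

```python
# (before, after) totals in seconds for a.max(), numel() == 200000,
# copied from the benchmark output above.
timings = {
    ("default threads", "torch.double"): (2.3256353894248605, 0.550495645031333),
    ("default threads", "torch.float"): (3.31692426931113, 0.44446005020290613),
    ("default threads", "torch.int16"): (2.2510280115529895, 0.47057845536619425),
    ("single thread", "torch.double"): (2.234461876563728, 1.6108122132718563),
    ("single thread", "torch.float"): (3.2497852174565196, 0.7705567870289087),
    ("single thread", "torch.int16"): (2.171935741789639, 0.13242999743670225),
}


def speedup(before: float, after: float) -> float:
    """Ratio of old total time to new total time; > 1 means faster."""
    return before / after


speedups = {key: speedup(b, a) for key, (b, a) in timings.items()}
for (mode, dtype), s in speedups.items():
    print(f"{mode:>15}  {dtype:<13} {s:.2f}x")
```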
Fixes: https://github.com/pytorch/pytorch/issues/33197
Fixes: https://github.com/pytorch/pytorch/issues/24728, https://github.com/pytorch/pytorch/issues/24729
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33936
Differential Revision: D20461658
Pulled By: ngimel
fbshipit-source-id: 5749260114ace3ea7b513e32edc805c844a19c8a