Migrate equal from the TH to Aten (CPU) (#33286)
Summary:
https://github.com/pytorch/pytorch/issues/24697
VitalyFedyunin
glaringlee
Test script:
```Python
import timeit
setup_ones = """
import torch
a = torch.ones(({n}, {n}), dtype={dtype})
b = torch.ones(({n}, {n}), dtype={dtype})
"""
for n, t in [(1000, 10000), (2000, 10000)]:
for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.bfloat16', 'torch.float', 'torch.double'):
#for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.float', 'torch.double'):
print('torch.ones(({n}, {n})) equal for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_ones.format(n=n, dtype=dtype), number=t))
setup_rand = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
b = a.clone()
"""
for n, t in [(1000, 10000), (2000, 10000)]:
for dtype in ('torch.float', 'torch.double'):
print('torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_rand.format(n=n, dtype=dtype), number=t))
setup_non_contiguous = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
a2 = a[:, 500:]
a3 = a2.clone()
torch.equal(a2, a3)
"""
for n, t in [(1000, 10000), (2000, 10000)]:
for dtype in ('torch.float', 'torch.double'):
print('non_contiguous torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
print(timeit.timeit(stmt='torch.equal(a2, a3)', setup=setup_non_contiguous.format(n=n, dtype=dtype), number=t))
setup_not_equal = """
import torch
a = torch.rand(({n}, {n}), dtype={dtype})
b = torch.rand(({n}, {n}), dtype={dtype})
torch.equal(a, b)
"""
for n, t in [(1000, 10000), (2000, 10000)]:
for dtype in ('torch.float', 'torch.double'):
print('not equal torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype))
print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_not_equal.format(n=n, dtype=dtype), number=t))
```
TH
```
torch.ones((1000, 1000)) equal for 10000 times torch.bool
1.8391206220258027
torch.ones((1000, 1000)) equal for 10000 times torch.int
1.8877864250680432
torch.ones((1000, 1000)) equal for 10000 times torch.long
1.938108820002526
torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16
3.184849138953723
torch.ones((1000, 1000)) equal for 10000 times torch.float
1.8825413499725983
torch.ones((1000, 1000)) equal for 10000 times torch.double
2.7266416549682617
torch.ones((2000, 2000)) equal for 10000 times torch.bool
7.227149627986364
torch.ones((2000, 2000)) equal for 10000 times torch.int
7.76215292501729
torch.ones((2000, 2000)) equal for 10000 times torch.long
9.631909006042406
torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16
8.097328286035918
torch.ones((2000, 2000)) equal for 10000 times torch.float
5.5739822529722005
torch.ones((2000, 2000)) equal for 10000 times torch.double
8.444009944912978
torch.rand((1000, 1000)) for 10000 times torch.float
1.168096570065245
torch.rand((1000, 1000)) for 10000 times torch.double
1.6577326939441264
torch.rand((2000, 2000)) for 10000 times torch.float
5.49395391496364
torch.rand((2000, 2000)) for 10000 times torch.double
8.507486199960113
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float
6.074504268006422
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double
6.1426916810451075
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float
37.501055537955835
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double
44.6880351039581
not equal torch.rand((1000, 1000)) for 10000 times torch.float
0.029356416082009673
not equal torch.rand((1000, 1000)) for 10000 times torch.double
0.025421109050512314
not equal torch.rand((2000, 2000)) for 10000 times torch.float
0.026333761983551085
not equal torch.rand((2000, 2000)) for 10000 times torch.double
0.02748022007290274
```
ATen
```
torch.ones((1000, 1000)) equal for 10000 times torch.bool
0.7961567062884569
torch.ones((1000, 1000)) equal for 10000 times torch.int
0.49172434909269214
torch.ones((1000, 1000)) equal for 10000 times torch.long
0.9459248608909547
torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16
2.0877483217045665
torch.ones((1000, 1000)) equal for 10000 times torch.float
0.606857153121382
torch.ones((1000, 1000)) equal for 10000 times torch.double
1.1388208279386163
torch.ones((2000, 2000)) equal for 10000 times torch.bool
2.0329296849668026
torch.ones((2000, 2000)) equal for 10000 times torch.int
3.534358019940555
torch.ones((2000, 2000)) equal for 10000 times torch.long
8.19841272290796
torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16
6.595649406313896
torch.ones((2000, 2000)) equal for 10000 times torch.float
4.193911510054022
torch.ones((2000, 2000)) equal for 10000 times torch.double
7.931309659034014
torch.rand((1000, 1000)) for 10000 times torch.float
0.8877940969541669
torch.rand((1000, 1000)) for 10000 times torch.double
1.4142901846207678
torch.rand((2000, 2000)) for 10000 times torch.float
4.010025603231043
torch.rand((2000, 2000)) for 10000 times torch.double
8.126411964651197
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float
0.602473056409508
non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double
0.6784545010887086
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float
3.0991827426478267
non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double
5.719010795000941
not equal torch.rand((1000, 1000)) for 10000 times torch.float
0.046060710679739714
not equal torch.rand((1000, 1000)) for 10000 times torch.double
0.036034489050507545
not equal torch.rand((2000, 2000)) for 10000 times torch.float
0.03686975734308362
not equal torch.rand((2000, 2000)) for 10000 times torch.double
0.04189508780837059
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33286
Differential Revision: D22211962
Pulled By: glaringlee
fbshipit-source-id: a5c48f328432c1996f28e19bc75cb495fb689f6b