pytorch
72f2c479 - Migrate equal from the TH to Aten (CPU) (#33286)

Commit
4 years ago
Migrate equal from the TH to Aten (CPU) (#33286) Summary: https://github.com/pytorch/pytorch/issues/24697 VitalyFedyunin glaringlee Test script: ```Python import timeit setup_ones = """ import torch a = torch.ones(({n}, {n}), dtype={dtype}) b = torch.ones(({n}, {n}), dtype={dtype}) """ for n, t in [(1000, 10000), (2000, 10000)]: for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.bfloat16', 'torch.float', 'torch.double'): #for dtype in ('torch.bool', 'torch.int', 'torch.long', 'torch.float', 'torch.double'): print('torch.ones(({n}, {n})) equal for {t} times {dtype}'.format(n=n, t=t, dtype=dtype)) print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_ones.format(n=n, dtype=dtype), number=t)) setup_rand = """ import torch a = torch.rand(({n}, {n}), dtype={dtype}) b = a.clone() """ for n, t in [(1000, 10000), (2000, 10000)]: for dtype in ('torch.float', 'torch.double'): print('torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype)) print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_rand.format(n=n, dtype=dtype), number=t)) setup_non_contiguous = """ import torch a = torch.rand(({n}, {n}), dtype={dtype}) a2 = a[:, 500:] a3 = a2.clone() torch.equal(a2, a3) """ for n, t in [(1000, 10000), (2000, 10000)]: for dtype in ('torch.float', 'torch.double'): print('non_contiguous torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype)) print(timeit.timeit(stmt='torch.equal(a2, a3)', setup=setup_non_contiguous.format(n=n, dtype=dtype), number=t)) setup_not_equal = """ import torch a = torch.rand(({n}, {n}), dtype={dtype}) b = torch.rand(({n}, {n}), dtype={dtype}) torch.equal(a, b) """ for n, t in [(1000, 10000), (2000, 10000)]: for dtype in ('torch.float', 'torch.double'): print('not equal torch.rand(({n}, {n})) for {t} times {dtype}'.format(n=n, t=t, dtype=dtype)) print(timeit.timeit(stmt='torch.equal(a, b)', setup=setup_not_equal.format(n=n, dtype=dtype), number=t)) ``` TH ``` torch.ones((1000, 1000)) equal for 10000 times torch.bool 1.8391206220258027 torch.ones((1000, 1000)) equal for 10000 times torch.int 1.8877864250680432 torch.ones((1000, 1000)) equal for 10000 times torch.long 1.938108820002526 torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16 3.184849138953723 torch.ones((1000, 1000)) equal for 10000 times torch.float 1.8825413499725983 torch.ones((1000, 1000)) equal for 10000 times torch.double 2.7266416549682617 torch.ones((2000, 2000)) equal for 10000 times torch.bool 7.227149627986364 torch.ones((2000, 2000)) equal for 10000 times torch.int 7.76215292501729 torch.ones((2000, 2000)) equal for 10000 times torch.long 9.631909006042406 torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16 8.097328286035918 torch.ones((2000, 2000)) equal for 10000 times torch.float 5.5739822529722005 torch.ones((2000, 2000)) equal for 10000 times torch.double 8.444009944912978 torch.rand((1000, 1000)) for 10000 times torch.float 1.168096570065245 torch.rand((1000, 1000)) for 10000 times torch.double 1.6577326939441264 torch.rand((2000, 2000)) for 10000 times torch.float 5.49395391496364 torch.rand((2000, 2000)) for 10000 times torch.double 8.507486199960113 non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float 6.074504268006422 non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double 6.1426916810451075 non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float 37.501055537955835 non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double 44.6880351039581 not equal torch.rand((1000, 1000)) for 10000 times torch.float 0.029356416082009673 not equal torch.rand((1000, 1000)) for 10000 times torch.double 0.025421109050512314 not equal torch.rand((2000, 2000)) for 10000 times torch.float 0.026333761983551085 not equal torch.rand((2000, 2000)) for 10000 times torch.double 0.02748022007290274 ``` ATen ``` torch.ones((1000, 1000)) equal for 10000 times torch.bool 0.7961567062884569 torch.ones((1000, 1000)) equal for 10000 times torch.int 0.49172434909269214 torch.ones((1000, 1000)) equal for 10000 times torch.long 0.9459248608909547 torch.ones((1000, 1000)) equal for 10000 times torch.bfloat16 2.0877483217045665 torch.ones((1000, 1000)) equal for 10000 times torch.float 0.606857153121382 torch.ones((1000, 1000)) equal for 10000 times torch.double 1.1388208279386163 torch.ones((2000, 2000)) equal for 10000 times torch.bool 2.0329296849668026 torch.ones((2000, 2000)) equal for 10000 times torch.int 3.534358019940555 torch.ones((2000, 2000)) equal for 10000 times torch.long 8.19841272290796 torch.ones((2000, 2000)) equal for 10000 times torch.bfloat16 6.595649406313896 torch.ones((2000, 2000)) equal for 10000 times torch.float 4.193911510054022 torch.ones((2000, 2000)) equal for 10000 times torch.double 7.931309659034014 torch.rand((1000, 1000)) for 10000 times torch.float 0.8877940969541669 torch.rand((1000, 1000)) for 10000 times torch.double 1.4142901846207678 torch.rand((2000, 2000)) for 10000 times torch.float 4.010025603231043 torch.rand((2000, 2000)) for 10000 times torch.double 8.126411964651197 non_contiguous torch.rand((1000, 1000)) for 10000 times torch.float 0.602473056409508 non_contiguous torch.rand((1000, 1000)) for 10000 times torch.double 0.6784545010887086 non_contiguous torch.rand((2000, 2000)) for 10000 times torch.float 3.0991827426478267 non_contiguous torch.rand((2000, 2000)) for 10000 times torch.double 5.719010795000941 not equal torch.rand((1000, 1000)) for 10000 times torch.float 0.046060710679739714 not equal torch.rand((1000, 1000)) for 10000 times torch.double 0.036034489050507545 not equal torch.rand((2000, 2000)) for 10000 times torch.float 0.03686975734308362 not equal torch.rand((2000, 2000)) for 10000 times torch.double 0.04189508780837059 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/33286 Differential Revision: D22211962 Pulled By: glaringlee fbshipit-source-id: a5c48f328432c1996f28e19bc75cb495fb689f6b
Author
Parents
Loading