pytorch
ae0732cd - Speed up an integer to the power of a positive integer on CPU (#26020)

Commit
6 years ago
Speed up an integer to the power of a positive integer on CPU (#26020) Summary: Current integer scalar exps are always cast to double. This commit avoids cast if the tensor is also integral and the scalar is positive to speed up. Benchmark (Debian Buster, g++ 8, Intel(R) Xeon(R) E-2136 CPU @ 3.30GHz 0 0:0 3300.00 MHz , Debug build, Turbo turned off): ```python import timeit for n, t in [(1000, 13000), (10_000, 1300)]: for e in (2, 3, 4): for dtype in ('torch.int16', 'torch.int32', 'torch.int64'): print(f'a.pow({e}) (a.numel() == {n}) for {t} times') print(f'dtype {dtype}, {t} times', end='\t\t') print(timeit.timeit(f'a.pow({e})', setup=f'import torch; a = torch.arange({n}, device="cpu", dtype={dtype})', number=t)) ``` Before: ``` a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 1.6958350749996498 a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 0.7989626339999631 a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 0.7973162800003593 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 1.8660746679997828 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 0.8101709959996697 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 0.8135280149999744 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 5.010833072999958 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 4.801007671999741 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 3.963344578000033 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 1.6216251330001796 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 0.5672429639998882 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 0.5544572270000572 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 1.656308512999658 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 1.502670819999821 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 0.5757876879997639 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 4.775718216999849 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 4.754745475000163 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 3.737249878000057 ``` After: ``` a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 1.1006453190002503 a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 1.0849009019998448 a.pow(2) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 1.093259106000005 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 1.0859826279997833 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 1.1076840900000207 a.pow(3) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 1.0755480369998622 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int16, 13000 times 1.918211066999902 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int32, 13000 times 1.9183043200000611 a.pow(4) (a.numel() == 1000) for 13000 times dtype torch.int64, 13000 times 1.930021430999659 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 0.7271483560002707 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 0.7289002070001516 a.pow(2) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 0.7267536800000016 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 0.7301799359997858 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 0.7289195180001116 a.pow(3) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 0.7270008230002531 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int16, 1300 times 1.5354506029998447 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int32, 1300 times 1.528263066999898 a.pow(4) (a.numel() == 10000) for 1300 times dtype torch.int64, 1300 times 1.5369428439998956 ``` --- Best viewed with whitespace changes turned off Pull Request resolved: https://github.com/pytorch/pytorch/pull/26020 Differential Revision: D17485400 Pulled By: VitalyFedyunin fbshipit-source-id: 3a16b074825a5aab0f7e7af3d8100f9e4b7011a3
Author
Parents
Loading