Speed up an integer to the power of a positive integer on CPU (#26020)
Summary:
Current integer scalar exps are always cast to double. This commit avoids cast if the tensor is also
integral and the scalar is positive to speed up.
Benchmark (Debian Buster, g++ 8, Intel(R) Xeon(R) E-2136 CPU @ 3.30GHz 0 0:0 3300.00 MHz , Debug
build, Turbo turned off):
```python
import timeit
for n, t in [(1000, 13000),
(10_000, 1300)]:
for e in (2, 3, 4):
for dtype in ('torch.int16', 'torch.int32', 'torch.int64'):
print(f'a.pow({e}) (a.numel() == {n}) for {t} times')
print(f'dtype {dtype}, {t} times', end='\t\t')
print(timeit.timeit(f'a.pow({e})',
setup=f'import torch; a = torch.arange({n}, device="cpu", dtype={dtype})',
number=t))
```
Before:
```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.6958350749996498
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 0.7989626339999631
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 0.7973162800003593
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.8660746679997828
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 0.8101709959996697
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 0.8135280149999744
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 5.010833072999958
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 4.801007671999741
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 3.963344578000033
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.6216251330001796
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.5672429639998882
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.5544572270000572
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.656308512999658
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 1.502670819999821
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.5757876879997639
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 4.775718216999849
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 4.754745475000163
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 3.737249878000057
```
After:
```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.1006453190002503
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.0849009019998448
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.093259106000005
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.0859826279997833
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.1076840900000207
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.0755480369998622
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.918211066999902
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.9183043200000611
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.930021430999659
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 0.7271483560002707
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.7289002070001516
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.7267536800000016
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 0.7301799359997858
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.7289195180001116
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.7270008230002531
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.5354506029998447
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 1.528263066999898
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 1.5369428439998956
```
---
Best viewed with whitespace changes turned off
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26020
Differential Revision: D17485400
Pulled By: VitalyFedyunin
fbshipit-source-id: 3a16b074825a5aab0f7e7af3d8100f9e4b7011a3