Migrate pow from TH to Aten (CUDA) (#25517)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24613
```
DEBUG = 0
OMP_NUM_THREADS = 1
Tesla M40
import torch
base = torch.randn(1000000, device='cuda:1')
exp = torch.randn(1000000, device='cuda:1')
out = torch.empty_like(base)
timeit base.pow(0)
old 53.1 µs ± 22.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 18.7 µs ± 15 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
timeit base.pow(1/3)
old 53.3 µs ± 20.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 51.1 µs ± 101 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1/3)
old 53.3 µs ± 55.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 51.1 µs ± 29.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(1/2)
old 53.2 µs ± 38.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 34.8 µs ± 40.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1/2)
old 53.3 µs ± 54.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 42 µs ± 32.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(1)
old 38.3 µs ± 53.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 40.1 µs ± 41.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1)
old 38.4 µs ± 29 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 35 µs ± 143 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(2)
old 38.1 µs ± 20.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 34.8 µs ± 90.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-2)
old 38.3 µs ± 11.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 35.2 µs ± 54.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(3)
old 38.3 µs ± 164 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 34.9 µs ± 46.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-3)
old 53.3 µs ± 89.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 51.4 µs ± 31.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(123456.789)
old 53.3 µs ± 12.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 51.2 µs ± 24.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-123456.789)
old 53.5 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 51.3 µs ± 66.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(exp)
old 58.2 µs ± 25.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 54.5 µs ± 25.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(0, exp)
old 49.1 µs ± 89.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 58.7 µs ± 125 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(1, exp)
old 48.7 µs ± 26.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 18.7 µs ± 88.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
timeit torch.pow(-1, exp)
old 50.7 µs ± 104 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 59.8 µs ± 100 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(42, exp)
old 49.4 µs ± 98 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 58.6 µs ± 26.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-42, exp)
old 50.4 µs ± 131 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 59.8 µs ± 48.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(0, exp, out=out)
old 49 µs ± 13 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 59.2 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(1, exp, out=out)
old 49.3 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 18.8 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
timeit torch.pow(-1, exp, out=out)
old 50.4 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 60.2 µs ± 71.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(42, exp, out=out)
old 49.2 µs ± 293 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 58.9 µs ± 193 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-42, exp, out=out)
old 50.5 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 60.1 µs ± 89.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
base = (torch.rand(1000000, device='cuda:1') * 10).to(int)
exp = (torch.rand(1000000, device='cuda:1') * 10).to(int)
out = torch.empty_like(base)
timeit base.pow(0)
old 75.5 µs ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 33.8 µs ± 84.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(1/3)
old 75.5 µs ± 78.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 842 µs ± 449 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1/3)
old 75.5 µs ± 24.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 843 µs ± 231 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(1/2)
old 75.7 µs ± 141 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 123 µs ± 71.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1/2)
old 76 µs ± 162 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 180 µs ± 55.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(1)
old 74.1 µs ± 25.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 72.3 µs ± 32.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-1.0)
old Integers to negative integer powers are not allowed.
new 86.9 µs ± 84.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(2)
old 74.2 µs ± 15.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 66.5 µs ± 28.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-2.0)
old Integers to negative integer powers are not allowed.
new 87.3 µs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(3)
old 74.3 µs ± 23.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 66.5 µs ± 43.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-3.0)
old Integers to negative integer powers are not allowed.
new 861 µs ± 372 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(123456.789)
old 256 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 863 µs ± 64.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(-123456.789)
old Integers to negative integer powers are not allowed.
new 863 µs ± 57.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit base.pow(exp)
old 111 µs ± 14.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 98.8 µs ± 16 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(0, exp)
old 81.9 µs ± 23.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 92.9 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(1, exp)
old 81.9 µs ± 25.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 33.6 µs ± 56.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-1, exp)
old 82.2 µs ± 15.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 93.6 µs ± 161 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(42, exp)
old 82.1 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 93.8 µs ± 75.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-42, exp)
old 82.3 µs ± 18.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 94 µs ± 68.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(0, exp, out=out)
old 81.6 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 93.8 µs ± 83.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(1, exp, out=out)
old 81.6 µs ± 26.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 33.7 µs ± 36.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-1, exp, out=out)
old 82.7 µs ± 119 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 93.9 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(42, exp, out=out)
old 82.6 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 93.7 µs ± 144 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
timeit torch.pow(-42, exp, out=out)
old 82.5 µs ± 214 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
new 94 µs ± 55.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25517
Differential Revision: D17251364
Pulled By: pbelevich
fbshipit-source-id: 20904c073c311e76285eaa1b68e67e67ea3c62d8