Move leaky_relu to Aten(CPU, CUDA) (#29899)
Summary:
VitalyFedyunin, This PR is about port LeakyReLU activation to Aten:
**Test script:**
```
import torch
import torch.nn as nn
import time
torch.manual_seed(0)
def _time():
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.time()
device = "cpu"
m = nn.LeakyReLU()
if torch.cuda.is_available():
device = "cuda"
m = m.cuda()
#warm up
for n in [100, 10000]:
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.ones(128, n, device=device)
for i in range(1000):
output = m(input)
output.backward(grad_output)
for n in [100, 10000]:
fwd_t = 0
bwd_t = 0
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.ones(128, n, device=device)
for i in range(10000):
t1 = _time()
output = m(input)
t2 = _time()
output.backward(grad_output)
t3 = _time()
fwd_t = fwd_t + (t2 -t1)
bwd_t = bwd_t + (t3 - t2)
fwd_avg = fwd_t / 10000 * 1000
bwd_avg = bwd_t / 10000 * 1000
print("input size(128, %d) forward time is %.2f (ms); backwad avg time is %.2f (ms)."
% (n, fwd_avg, bwd_avg))
```
Test Device: CPU: skx-8180, GPU: Tesla P40.
Perfromance:
Before:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backwad avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backwad avg time is 0.17 (ms).
CPU:
OMP_NUM_THREADS=56
input size(128, 100) forward time is 0.05 (ms); backwad avg time is 0.14 (ms).
input size(128, 10000) forward time is 4.21 (ms); backwad avg time is 8.02 (ms).
OMP_NUM_THREADS=1
input size(128, 100) forward time is 0.02 (ms); backwad avg time is 0.07 (ms).
input size(128, 10000) forward time is 1.98 (ms); backwad avg time is 6.21 (ms)
```
After:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backwad avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backwad avg time is 0.17 (ms).
CPU:
OMP_NUM_THREADS=56
input size(128, 100) forward time is 0.02 (ms); backwad avg time is 0.04 (ms).
input size(128, 10000) forward time is 0.03 (ms); backwad avg time is 0.09 (ms).
OMP_NUM_THREADS=1
input size(128, 100) forward time is 0.01 (ms); backwad avg time is 0.02 (ms).
input size(128, 10000) forward time is 0.47 (ms); backwad avg time is 1.02 (ms).
```
How to set the numbers of thread? using following script:
```
num_threads=$1
script=$2
last_core=`expr $num_threads - 1`
echo "using $num_threads OMP threads"
echo "bind cores to 0~$last_core"
export OMP_NUM_THREADS=$num_threads
export KMP_AFFINITY=granularity=fine,compact,1,0
numactl --physcpubind=0-$last_core --membind=0 python $script
```
and run .**/run.sh num_threads test.py**.
Fixes https://github.com/pytorch/pytorch/issues/24583 #24584 https://github.com/pytorch/pytorch/issues/24720 #24721
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29899
Differential Revision: D18816231
Pulled By: VitalyFedyunin
fbshipit-source-id: afb1e43a99317d17f50cff1b593cd8f7a0a83da2