Move Softshrink activation to Aten(CPU+CUDA) (#30229)
Summary:
VitalyFedyunin, This PR is about port Softshrink activation to Aten:
**Test script:**
```
import torch
import torch.nn as nn
import time
torch.manual_seed(0)
def _time():
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.time()
device = "cpu"
m = nn.Softshrink()
if torch.cuda.is_available():
device = "cuda"
m = m.cuda()
#warm up
for n in [100, 10000]:
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.ones(128, n, device=device)
for i in range(1000):
output = m(input)
output.backward(grad_output)
for n in [100, 10000]:
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.ones(128, n, device=device)
fwd_t = 0
bwd_t = 0
for i in range(10000):
t1 = _time()
output = m(input)
t2 = _time()
output.backward(grad_output)
t3 = _time()
fwd_t = fwd_t + (t2 -t1)
bwd_t = bwd_t + (t3 - t2)
fwd_avg = fwd_t / 10000 * 1000
bwd_avg = bwd_t / 10000 * 1000
print("input size(128, %d) forward time is %.2f (ms); backwad avg time is %.2f (ms)."
% (n, fwd_avg, bwd_avg))
```
Test Device: CPU: skx-8180, GPU: Tesla P40.
Perfromance:
Before:
```
GPU:
input size(128, 100) forward time is 0.06 (ms); backwad avg time is 0.12 (ms).
input size(128, 10000) forward time is 0.06 (ms); backwad avg time is 0.18 (ms).
CPU:
input size(128, 100) forward time is 0.19 (ms); backwad avg time is 0.23 (ms).
input size(128, 10000) forward time is 17.23 (ms); backwad avg time is 16.83 (ms).
```
After:
```
GPU:
input size(128, 100) forward time is 0.05 (ms); backwad avg time is 0.11 (ms).
input size(128, 10000) forward time is 0.06 (ms); backwad avg time is 0.17 (ms).
CPU:
input size(128, 100) forward time is 0.08 (ms); backwad avg time is 0.05 (ms).
input size(128, 10000) forward time is 0.32 (ms); backwad avg time is 0.08 (ms).
```
`OMP_NUM_THREADS=1:`
```
Before:
input size(128, 100) forward time is 0.08 (ms); backwad avg time is 0.10 (ms).
input size(128, 10000) forward time is 7.58 (ms); backwad avg time is 7.91 (ms).
After:
input size(128, 100) forward time is 0.08 (ms); backwad avg time is 0.02 (ms).
input size(128, 10000) forward time is 7.30 (ms); backwad avg time is 1.02 (ms).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30229
Differential Revision: D18810054
Pulled By: VitalyFedyunin
fbshipit-source-id: e19074824396570db45ba488ae4f9fe1b07a5839