Move log_sigmoid to Aten(CPU) (#30958)
Summary:
VitalyFedyunin, This PR is about port LogSigmoid activation to Aten:
Test script:
```
import torch
import torch.nn as nn
import time
torch.manual_seed(0)
def _time():
return time.time()
device = "cpu"
m = nn.LogSigmoid()
#warm up
for n in [1, 10, 100, 1000]:
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.randn(128, n, device=device)
for i in range(1000):
output = m(input)
output.backward(grad_output)
for n in [1, 10, 100, 1000]:
input = torch.randn(128, n, requires_grad=True, device=device)
grad_output = torch.randn(128, n, device=device)
fwd_t = 0
bwd_t = 0
for i in range(10000):
t1 = _time()
output = m(input)
t2 = _time()
output.backward(grad_output)
t3 = _time()
fwd_t = fwd_t + (t2 -t1)
bwd_t = bwd_t + (t3 - t2)
fwd_avg = fwd_t / 10000 * 1000
bwd_avg = bwd_t / 10000 * 1000
print("input size(128, %d) forward time is %.2f (ms); backwad avg time is %.2f (ms)."
% (n, fwd_avg, bwd_avg))
```
**Before:**
```
input size(128, 1) forward time is 0.02 (ms); backwad avg time is 0.02 (ms).
input size(128, 10) forward time is 0.10 (ms); backwad avg time is 0.03 (ms).
input size(128, 100) forward time is 0.90 (ms); backwad avg time is 0.09 (ms).
input size(128, 1000) forward time is 9.04 (ms); backwad avg time is 0.87 (ms).
```
**After:**
```
input size(128, 1) forward time is 0.02 (ms); backwad avg time is 0.02 (ms).
input size(128, 10) forward time is 0.02 (ms); backwad avg time is 0.02 (ms).
input size(128, 100) forward time is 0.04 (ms); backwad avg time is 0.03 (ms).
input size(128, 1000) forward time is 0.28 (ms); backwad avg time is 0.07 (ms).
```
**OMP_NUM_THREADS=1:**
```
Before:
input size(128, 1) forward time is 0.02 (ms); backwad avg time is 0.02 (ms).
input size(128, 10) forward time is 0.10 (ms); backwad avg time is 0.03 (ms).
input size(128, 100) forward time is 0.88 (ms); backwad avg time is 0.10 (ms).
input size(128, 1000) forward time is 8.72 (ms); backwad avg time is 0.81 (ms).
After:
input size(128, 1) forward time is 0.01 (ms); backwad avg time is 0.02 (ms).
input size(128, 10) forward time is 0.02 (ms); backwad avg time is 0.02 (ms).
input size(128, 100) forward time is 0.07 (ms); backwad avg time is 0.03 (ms).
input size(128, 1000) forward time is 0.63 (ms); backwad avg time is 0.15 (ms).
```
Fix https://github.com/pytorch/pytorch/issues/24724, https://github.com/pytorch/pytorch/issues/24725.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30958
Differential Revision: D19275111
Pulled By: ezyang
fbshipit-source-id: bbfe82e58fb27a4fb21c1914c6547a9050072e5c