Port mse_loss to ATen (#26529)
Summary:
VitalyFedyunin, this PR ports mse_loss to ATen.
**Test script:**
```python
import torch
import torch.nn as nn
import time

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
loss = nn.MSELoss(reduction='sum')
if torch.cuda.is_available():
    device = "cuda"
    loss = loss.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(1000):
        output = loss(input, target)
        output.backward()

# get running time
for n in [100, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(10000):
        t1 = _time()
        output = loss(input, target)
        t2 = _time()
        output.backward()
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```
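For reference, the math the ported forward and backward kernels compute is straightforward. Below is a minimal pure-Python sketch of it, purely for illustration; it is not the ATen implementation (which is vectorized and parallelized), and the function names `mse_forward`/`mse_backward` are hypothetical:

```python
# Illustration only: the elementwise math behind mse_loss, not the ATen kernel.

def mse_forward(inp, tgt, reduction="mean"):
    # elementwise squared error, then reduce by sum or mean
    sq = [(x - y) ** 2 for x, y in zip(inp, tgt)]
    total = sum(sq)
    return total / len(sq) if reduction == "mean" else total

def mse_backward(inp, tgt, grad_out=1.0, reduction="mean"):
    # d/dx of (x - y)^2 is 2 * (x - y); 'mean' additionally scales by 1/N
    scale = grad_out * (2.0 / len(inp) if reduction == "mean" else 2.0)
    return [scale * (x - y) for x, y in zip(inp, tgt)]

print(mse_forward([1.0, 2.0], [0.0, 2.0], "sum"))              # 1.0
print(mse_backward([1.0, 2.0], [0.0, 2.0], reduction="sum"))   # [2.0, 0.0]
```

The benchmark above exercises exactly these two paths, timing `loss(input, target)` (forward) and `output.backward()` (backward) separately.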
**Test Device:** CPU: skx-8180, GPU: Tesla P40.
### Performance:
**Before:**
```
GPU:
reduction='mean'
input size(128, 100) forward time is 0.08 (ms); backward avg time is 0.14 (ms).
input size(128, 10000) forward time is 0.12 (ms); backward avg time is 0.21 (ms).
reduction='sum'
input size(128, 100) forward time is 0.09 (ms); backward avg time is 0.15 (ms).
input size(128, 10000) forward time is 0.11 (ms); backward avg time is 0.20 (ms).
CPU:
OMP_NUM_THREADS=56
reduction='mean'
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 3.49 (ms); backward avg time is 3.23 (ms).
reduction='sum'
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 3.49 (ms); backward avg time is 3.23 (ms).
OMP_NUM_THREADS=1
reduction='mean'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 1.41 (ms); backward avg time is 1.66 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 1.44 (ms); backward avg time is 1.68 (ms).
```
**After:**
```
GPU:
reduction='mean'
input size(128, 100) forward time is 0.07 (ms); backward avg time is 0.13 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.20 (ms).
reduction='sum'
input size(128, 100) forward time is 0.07 (ms); backward avg time is 0.14 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.20 (ms).
CPU:
OMP_NUM_THREADS=56
reduction='mean'
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.06 (ms).
input size(128, 10000) forward time is 0.14 (ms); backward avg time is 0.30 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.06 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.30 (ms).
OMP_NUM_THREADS=1
reduction='mean'
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 0.85 (ms); backward avg time is 1.27 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 0.83 (ms); backward avg time is 1.27 (ms).
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26529
Differential Revision: D18225144
Pulled By: VitalyFedyunin
fbshipit-source-id: ce837a297c70398a3ffa22f26ee9e812cf60d128