pytorch
c8771f5a - Port mse_loss to ATen (#26529)

5 years ago
Port mse_loss to ATen (#26529)

Summary: VitalyFedyunin, this PR ports mse_loss to ATen.

**Test script:**

```python
import torch
import torch.nn as nn
import time

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
loss = nn.MSELoss(reduction='sum')
if torch.cuda.is_available():
    device = "cuda"
    loss = loss.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(1000):
        output = loss(input, target)
        output.backward()

# get running time
for n in [100, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(10000):
        t1 = _time()
        output = loss(input, target)
        t2 = _time()
        output.backward()
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)." % (n, fwd_avg, bwd_avg))
```

**Test Device:** CPU: skx-8180, GPU: Tesla P40.

### Performance:

**Before:**

```
GPU:
reduction='mean'
input size(128, 100) forward time is 0.08 (ms); backward avg time is 0.14 (ms).
input size(128, 10000) forward time is 0.12 (ms); backward avg time is 0.21 (ms).
reduction='sum'
input size(128, 100) forward time is 0.09 (ms); backward avg time is 0.15 (ms).
input size(128, 10000) forward time is 0.11 (ms); backward avg time is 0.20 (ms).

CPU:
OMP_NUM_THREADS=56
reduction='mean'
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 3.49 (ms); backward avg time is 3.23 (ms).
reduction='sum'
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 3.49 (ms); backward avg time is 3.23 (ms).

OMP_NUM_THREADS=1
reduction='mean'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 1.41 (ms); backward avg time is 1.66 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 1.44 (ms); backward avg time is 1.68 (ms).
```

**After:**

```
GPU:
reduction='mean'
input size(128, 100) forward time is 0.07 (ms); backward avg time is 0.13 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.20 (ms).
reduction='sum'
input size(128, 100) forward time is 0.07 (ms); backward avg time is 0.14 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.20 (ms).

CPU:
OMP_NUM_THREADS=56
reduction='mean'
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.06 (ms).
input size(128, 10000) forward time is 0.14 (ms); backward avg time is 0.30 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.06 (ms).
input size(128, 10000) forward time is 0.13 (ms); backward avg time is 0.30 (ms).

OMP_NUM_THREADS=1
reduction='mean'
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 0.85 (ms); backward avg time is 1.27 (ms).
reduction='sum'
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.04 (ms).
input size(128, 10000) forward time is 0.83 (ms); backward avg time is 1.27 (ms).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26529

Differential Revision: D18225144

Pulled By: VitalyFedyunin

fbshipit-source-id: ce837a297c70398a3ffa22f26ee9e812cf60d128
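The benchmark harness above follows a common pattern: call `torch.cuda.synchronize()` before reading the clock (so asynchronous CUDA kernels are fully counted), accumulate wall time over many iterations, and report the per-iteration average in milliseconds. A device-free sketch of that accumulation logic, where `work` is a hypothetical stand-in for the forward or backward call being measured:

```python
import time

def bench(work, iters=1000):
    # Accumulate wall-clock time across iterations, return the average in ms.
    # In the GPU case one would synchronize the device before each time.time()
    # read, as the commit's _time() helper does.
    total = 0.0
    for _ in range(iters):
        t0 = time.time()
        work()
        total += time.time() - t0
    return total / iters * 1000.0

avg_ms = bench(lambda: sum(i * i for i in range(1000)))
```

This mirrors how `fwd_avg` and `bwd_avg` are computed in the test script (`fwd_t / 10000 * 1000`), just factored into a reusable helper.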
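For readers unfamiliar with the op being ported: `mse_loss` computes elementwise squared error, optionally reduced by `sum` or `mean`. A minimal pure-Python sketch of the semantics (not the ATen implementation; the helper here just mirrors the reductions of `torch.nn.functional.mse_loss` on flat lists):

```python
def mse_loss(input, target, reduction="mean"):
    # Elementwise squared error, reduced per torch.nn.MSELoss semantics.
    sq = [(a - b) ** 2 for a, b in zip(input, target)]
    if reduction == "sum":
        return sum(sq)
    if reduction == "mean":
        return sum(sq) / len(sq)
    return sq  # reduction="none": per-element losses

mse_loss([1.0, 2.0, 3.0], [1.0, 0.0, 0.0], reduction="sum")  # (2-0)^2 + (3-0)^2 = 13.0
```

The benchmark exercises both the `'sum'` and `'mean'` reductions, since the port has to reproduce each reduction's forward and backward behavior.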