e01324d0 - Port l1_loss to Aten (#26795)

Summary: VitalyFedyunin, this PR ports L1 loss to ATen.

**Test script:**

```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"
loss = nn.L1Loss(reduction='sum')
if torch.cuda.is_available():
    device = "cuda"
    loss = loss.cuda()

# warm up
for n in [100, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(1000):
        output = loss(input, target)
        output.backward()

# get running time
for n in [100, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    target = torch.randn(128, n, device=device)
    for i in range(10000):
        t1 = _time()
        output = loss(input, target)
        t2 = _time()
        output.backward()
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward time is %.2f (ms); backward avg time is %.2f (ms)." % (n, fwd_avg, bwd_avg))
```

Test device: CPU: skx-8180, GPU: Tesla P100.

**Performance:**

Before:

```
GPU:
reduction='mean':
input size(128, 100) forward time is 0.31 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 0.33 (ms); backward avg time is 0.14 (ms).
reduction='sum':
input size(128, 100) forward time is 0.31 (ms); backward avg time is 0.10 (ms).
input size(128, 10000) forward time is 0.34 (ms); backward avg time is 0.14 (ms).

CPU:
reduction='mean':
input size(128, 100) forward time is 0.06 (ms); backward avg time is 0.10 (ms).
input size(128, 10000) forward time is 1.92 (ms); backward avg time is 2.96 (ms).
reduction='sum':
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.09 (ms).
input size(128, 10000) forward time is 1.96 (ms); backward avg time is 2.79 (ms).

num_threads = 1:
reduction='mean':
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 1.67 (ms); backward avg time is 2.50 (ms).
reduction='sum':
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 1.67 (ms); backward avg time is 2.51 (ms).
```

After:

```
GPU:
reduction='mean':
input size(128, 100) forward time is 0.05 (ms); backward avg time is 0.10 (ms).
input size(128, 10000) forward time is 0.11 (ms); backward avg time is 0.17 (ms).
reduction='sum':
input size(128, 100) forward time is 0.05 (ms); backward avg time is 0.08 (ms).
input size(128, 10000) forward time is 0.11 (ms); backward avg time is 0.16 (ms).

CPU:
reduction='mean':
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 0.14 (ms); backward avg time is 0.18 (ms).
reduction='sum':
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 0.15 (ms); backward avg time is 0.17 (ms).

num_threads = 1:
reduction='mean':
input size(128, 100) forward time is 0.04 (ms); backward avg time is 0.06 (ms).
input size(128, 10000) forward time is 1.05 (ms); backward avg time is 1.72 (ms).
reduction='sum':
input size(128, 100) forward time is 0.03 (ms); backward avg time is 0.05 (ms).
input size(128, 10000) forward time is 1.03 (ms); backward avg time is 1.71 (ms).
```

How to set the number of threads? Use the following script:

```
num_threads=$1
script=$2
last_core=`expr $num_threads - 1`
echo "using $num_threads OMP threads"
echo "bind cores to 0~$last_core"
export OMP_NUM_THREADS=$num_threads
export KMP_AFFINITY=granularity=fine,compact,1,0
numactl --physcpubind=0-$last_core --membind=0 python $script
```

and run `./run.sh 1 L1loss.py`.
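For context on what the ported operator computes: `nn.L1Loss` reduces the element-wise absolute differences, and with `reduction='sum'` its gradient with respect to the input is `sign(input - target)`. A minimal sketch (not part of the original benchmark; the small tensor sizes are illustrative) that checks the operator against a manual computation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input = torch.randn(4, 5, requires_grad=True)
target = torch.randn(4, 5)

# nn.L1Loss with reduction='sum' should match a manual |input - target| sum.
loss_fn = nn.L1Loss(reduction='sum')
auto_loss = loss_fn(input, target)
manual_loss = (input - target).abs().sum()
assert torch.allclose(auto_loss, manual_loss)

# The gradient of the sum-reduced L1 loss w.r.t. the input is sign(input - target)
# (randn makes an exact zero difference vanishingly unlikely, so sign is well defined).
auto_loss.backward()
assert torch.allclose(input.grad, torch.sign(input - target))
```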
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26795
Differential Revision: D18140434
Pulled By: VitalyFedyunin
fbshipit-source-id: d0b976ec36797f2e6b4e58fbbac89688d29e736f