pytorch
be3808d3 - Migrate `smooth_l1_loss` from the TH to Aten (CPU & CUDA) (#27962)

Commit
5 years ago
Migrate `smooth_l1_loss` from the TH to Aten (CPU & CUDA) (#27962) Summary: This is a port of the TH `SmoothL1Criterion` to ATen using TensorIterator. The forward implementation has been placed in BinaryOpsKernel.cpp/.cu while the backward version was added to PointwiseOpsKernel.cpp/.cu. CPU performance has improved for both forward & backward path. With CUDA the performance of the forward pass has slightly degraded compared to the TH implementation (see benchmark results). ### Questions: 1. Is the storage location of the implementation ok (I followed https://github.com/pytorch/pytorch/pull/26529) or should we create a separate .cpp/.h file pair for each operator implementation (e.g. to keep things together)? 2. The GPU forward-pass now seems to take consistently longer than the old version. Any ideas what we could try to bring it on par with the old impl? ## WITH patch benchmark result: ``` CPU warmup 1000 took 0.00018124299822375178 CPU warmup 10000 took 0.00021713999740313739 CPU warmup 100000 took 0.0016273759974865243 CPU warmup TOTAL time 0.0020758909959113225 CPU forward 1000 took 6.229899736354128e-05 CPU forward 10000 took 0.00013340599980438128 CPU forward 100000 took 0.0008730469999136403 CPU forward 1000000 took 0.011010036003426649 CPU forward 10000000 took 0.11133221499767387 CPU forward 100000000 took 1.0425375220002024 CPU forward TOTAL time 1.1660894790038583 CPU for- & backward 1000 took 0.0002662249971763231 CPU for- & backward 10000 took 0.00023712700203759596 CPU for- & backward 100000 took 0.002531945996452123 CPU for- & backward 1000000 took 0.010394354998425115 CPU for- & backward 10000000 took 0.23814761800167616 CPU for- & backward 100000000 took 1.2651235049997922 CPU for- & backward TOTAL time 1.516897434994462 GPU warmup 1000 took 0.00020941899856552482 GPU warmup 10000 took 8.128300396492705e-05 GPU warmup 100000 took 8.551499922759831e-05 GPU warmup TOTAL time 0.0004199420000077225 GPU forward 1000 took 7.060499774524942e-05 GPU forward 10000 took 7.116600318113342e-05 GPU forward 100000 took 9.825800225371495e-05 GPU forward 1000000 took 0.000499356996442657 GPU forward 10000000 took 0.002032470001722686 GPU forward 100000000 took 0.018638986002770253 GPU forward TOTAL time 0.02148268099699635 GPU for- & backward 1000 took 0.00035967300209449604 GPU for- & backward 10000 took 0.00032710300001781434 GPU for- & backward 100000 took 0.0003689270015456714 GPU for- & backward 1000000 took 0.0007732619997113943 GPU for- & backward 10000000 took 0.02127284000016516 GPU for- & backward 100000000 took 0.2022330649997457 GPU for- & backward TOTAL time 0.2254496300010942 ``` ## WITHOUT patch benchmark result: ``` CPU warmup 1000 took 0.00011545199959073216 CPU warmup 10000 took 0.00016227000014623627 CPU warmup 100000 took 0.0013456509987008758 CPU warmup TOTAL time 0.001648657998885028 CPU forward 1000 took 2.627600042615086e-05 CPU forward 10000 took 0.00015939700097078457 CPU forward 100000 took 0.001139313004387077 CPU forward 1000000 took 0.013769682998827193 CPU forward 10000000 took 0.13163026500114938 CPU forward 100000000 took 1.321879123999679 CPU forward TOTAL time 1.4687001089987461 CPU for- & backward 1000 took 0.0002569290008977987 CPU for- & backward 10000 took 0.00033315900509478524 CPU for- & backward 100000 took 0.0016096779945655726 CPU for- & backward 1000000 took 0.014474845003860537 CPU for- & backward 10000000 took 0.1564881520025665 CPU for- & backward 100000000 took 1.5787935900007142 CPU for- & backward TOTAL time 1.7521004869995522 GPU warmup 1000 took 0.00025611399905756116 GPU warmup 10000 took 0.00014123699656920508 GPU warmup 100000 took 0.00012580600014189258 GPU warmup TOTAL time 0.0005591579974861816 GPU forward 1000 took 0.00031183200189843774 GPU forward 10000 took 0.00011483799607958645 GPU forward 100000 took 0.00010807999933604151 GPU forward 1000000 took 0.0007842139966669492 GPU forward 10000000 took 0.0017624700049054809 GPU forward 100000000 took 0.01519905700115487 GPU forward TOTAL time 0.018341148999752477 GPU for- & backward 1000 took 0.00047569099842803553 GPU for- & backward 10000 took 0.0003539700046530925 GPU for- & backward 100000 took 0.000808880002296064 GPU for- & backward 1000000 took 0.001639469999645371 GPU for- & backward 10000000 took 0.021154599002329633 GPU for- & backward 100000000 took 0.19268552300491137 GPU for- & backward TOTAL time 0.2172460189976846 ``` ### Code used for perforrmance testing ``` import torch import torch.nn.functional as F import torch.nn as nn from timeit import default_timer torch.manual_seed(0) cpu = torch.device('cpu') gpu = torch.device('cuda') loss_fn = F.smooth_l1_loss def run_benchmark(name, depth, require_grad, device, fn): total_start = default_timer() y = None a = None for i in range(3, 3 + depth): start = default_timer() n = 10 ** i a = torch.rand(n, requires_grad=require_grad, device=device) b = torch.rand(n, device=device) y = fn(a, b) y.cpu() # get result (potentially wait for gpu) if a.grad is not None: a.grad.cpu() end = default_timer() print('{} {} took {}'.format(name, n, end-start)) total_end = default_timer() print('{} TOTAL time {}'.format(name, total_end-total_start)) def fwd_only(a, b): out = loss_fn(a, b) return out def fwd_bck(a, b): out = loss_fn(a, b) out.backward() return out def sanity_check(name, device): print('{} Operator sanity check:'.format(name)) a = torch.randn(16, requires_grad=True, device=device) b = torch.randn(16, device=device) * 2 out = loss_fn(a, b) print('out', out) out.backward() print(a.grad) print('double backward') loss = loss_fn(a, b) loss2 = torch.autograd.grad(loss, a, create_graph=True) z = loss2[0].sum() print(z) z.backward() print('ok') print() print('PyTorch version:', torch.__version__) sanity_check('CPU', cpu) if torch.cuda.is_available(): sanity_check('GPU', gpu) print() run_benchmark('CPU warmup', 3, False, cpu, fwd_only) run_benchmark('CPU forward', 6, False, cpu, fwd_only) run_benchmark('CPU for- & backward', 6, True, cpu, fwd_bck) print() if torch.cuda.is_available(): run_benchmark('GPU warmup', 3, False, gpu, fwd_only) run_benchmark('GPU forward', 6, False, gpu, fwd_only) run_benchmark('GPU for- & backward', 6, True, gpu, fwd_bck) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/27962 Differential Revision: D18061942 Pulled By: ezyang fbshipit-source-id: 0d1fc528b59d47d4773b03240c3368db021cb9db
Author
Parents
Loading