9033ace9 - Migrate soft_margin_loss from the TH to Aten (CUDA+CPU) (#27673)

Migrate soft_margin_loss from the TH to Aten (CUDA+CPU) (#27673)

Summary: Replaces fused TH kernels with a 2-liner of regular Tensor functions. Benchmarking revealed that performance improves compared to PyTorch 1.2.

Refs: https://github.com/pytorch/pytorch/issues/24631, https://github.com/pytorch/pytorch/issues/24632, https://github.com/pytorch/pytorch/issues/24764, https://github.com/pytorch/pytorch/issues/24765

VitalyFedyunin

### Benchmarking results on my laptop

#### 1.4.0a0+f63c9e8 output

```
PyTorch version: 1.4.0a0+f63c9e8
CPU Operator sanity check:
tensor(0.5926, grad_fn=<MeanBackward0>)
tensor([-0.0159, -0.0170, -0.0011, -0.0083, -0.0140, -0.0217, -0.0290, -0.0262, -0.0078, -0.0129])
double backward
tensor(-0.1540, grad_fn=<SumBackward0>)
ok
GPU Operator sanity check:
tensor(0.5601, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([-0.0393, -0.0316, -0.0233, -0.0140, -0.0141, -0.0161, -0.0322, -0.0238, -0.0054, -0.0151], device='cuda:0')
double backward
tensor(-0.2148, device='cuda:0', grad_fn=<SumBackward0>)
ok
CPU warmup 1000 took 9.025700273923576e-05
CPU warmup 10000 took 0.0009383050055475906
CPU warmup 100000 took 0.0015631120040779933
CPU warmup TOTAL time 0.0026368020044174045
CPU forward 1000 took 6.919399311300367e-05
CPU forward 10000 took 0.00014462800754699856
CPU forward 100000 took 0.0011234670091653243
CPU forward 1000000 took 0.014555767003912479
CPU forward 10000000 took 0.13409724000666756
CPU forward 100000000 took 1.246048310000333
CPU forward TOTAL time 1.3961777170043206
CPU for- & backward 1000 took 0.0003219560021534562
CPU for- & backward 10000 took 0.00037290599721018225
CPU for- & backward 100000 took 0.001975035003852099
CPU for- & backward 1000000 took 0.02621342398924753
CPU for- & backward 10000000 took 0.2944270490115741
CPU for- & backward 100000000 took 1.6856628700043075
CPU for- & backward TOTAL time 2.0091958299890393
GPU warmup 1000 took 0.0002462909906171262
GPU warmup 10000 took 9.991199476644397e-05
GPU warmup 100000 took 0.00034347400651313365
GPU warmup TOTAL time 0.0007382350013358518
GPU forward 1000 took 9.67290106927976e-05
GPU forward 10000 took 9.349700121674687e-05
GPU forward 100000 took 9.384499571751803e-05
GPU forward 1000000 took 0.0004975290066795424
GPU forward 10000000 took 0.0017606960027478635
GPU forward 100000000 took 0.003572814996005036
GPU forward TOTAL time 0.006185991995153017
GPU for- & backward 1000 took 0.00035818999458570033
GPU for- & backward 10000 took 0.0003240450023440644
GPU for- & backward 100000 took 0.0003223370003979653
GPU for- & backward 1000000 took 0.00036740700306836516
GPU for- & backward 10000000 took 0.0003690610028570518
GPU for- & backward 100000000 took 0.0003672500024549663
GPU for- & backward TOTAL time 0.002197896988946013
```

#### 1.2 output

```
PyTorch version: 1.2.0
CPU Operator sanity check:
tensor(0.5926, grad_fn=<SoftMarginLossBackward>)
tensor([-0.0159, -0.0170, -0.0011, -0.0083, -0.0140, -0.0217, -0.0290, -0.0262, -0.0078, -0.0129])
double backward
tensor(-0.1540, grad_fn=<SumBackward0>)
ok
GPU Operator sanity check:
tensor(0.5601, device='cuda:0', grad_fn=<SoftMarginLossBackward>)
tensor([-0.0393, -0.0316, -0.0233, -0.0140, -0.0141, -0.0161, -0.0322, -0.0238, -0.0054, -0.0151], device='cuda:0')
double backward
tensor(-0.2148, device='cuda:0', grad_fn=<SumBackward0>)
ok
CPU warmup 1000 took 8.422900282312185e-05
CPU warmup 10000 took 0.00036992700188420713
CPU warmup 100000 took 0.003682684007799253
CPU warmup TOTAL time 0.004169487991021015
CPU forward 1000 took 5.521099956240505e-05
CPU forward 10000 took 0.00036948200431652367
CPU forward 100000 took 0.003762389998883009
CPU forward 1000000 took 0.03725024699815549
CPU forward 10000000 took 0.3614480490068672
CPU forward 100000000 took 3.6139175269927364
CPU forward TOTAL time 4.016912263003178
CPU for- & backward 1000 took 0.0002734809968387708
CPU for- & backward 10000 took 0.0006605249946005642
CPU for- & backward 100000 took 0.005437346000690013
CPU for- & backward 1000000 took 0.051245586000732146
CPU for- & backward 10000000 took 0.5291594529990107
CPU for- & backward 100000000 took 5.23841712900321
CPU for- & backward TOTAL time 5.8253340990049765
GPU warmup 1000 took 0.0005757809994975105
GPU warmup 10000 took 0.0004058420017827302
GPU warmup 100000 took 0.0003764610009966418
GPU warmup TOTAL time 0.0013992580061312765
GPU forward 1000 took 0.0003543390048434958
GPU forward 10000 took 0.0003633670130511746
GPU forward 100000 took 0.0004807310033356771
GPU forward 1000000 took 0.0005875999922864139
GPU forward 10000000 took 0.0016903509967960417
GPU forward 100000000 took 0.014400018990272656
GPU forward TOTAL time 0.0179396449966589
GPU for- & backward 1000 took 0.0006167769897729158
GPU for- & backward 10000 took 0.0006845899915788323
GPU for- & backward 100000 took 0.000631830989732407
GPU for- & backward 1000000 took 0.0010741150035755709
GPU for- & backward 10000000 took 0.0017265130009036511
GPU for- & backward 100000000 took 0.014847910992102697
GPU for- & backward TOTAL time 0.01965981800458394
```

### Code used for performance test

```
import torch
import torch.nn.functional as F
import torch.nn as nn
from timeit import default_timer

torch.manual_seed(0)

cpu = torch.device('cpu')
gpu = torch.device('cuda')

loss_fn = F.soft_margin_loss


def run_benchmark(name, depth, require_grad, device, fn):
    total_start = default_timer()
    for i in range(3, 3 + depth):
        start = default_timer()
        n = 10 ** i
        a = torch.rand(n, requires_grad=require_grad, device=device)
        b = torch.rand(n, device=device)
        fn(a, b)
        end = default_timer()
        print('{} {} took {}'.format(name, n, end - start))
    total_end = default_timer()
    print('{} TOTAL time {}'.format(name, total_end - total_start))


def fwd_only(a, b):
    out = loss_fn(a, b)


def fwd_bck(a, b):
    out = loss_fn(a, b)
    out.backward()


def sanity_check(name, device):
    print('{} Operator sanity check:'.format(name))
    a = torch.rand(10, requires_grad=True, device=device)
    b = torch.rand(10, device=device)
    out = loss_fn(a, b)
    print(out)
    out.backward()
    print(a.grad)
    print('double backward')
    loss = loss_fn(a, b)
    loss2 = torch.autograd.grad(loss, a, create_graph=True)
    z = loss2[0].sum()
    print(z)
    z.backward()
    print('ok')
    print()


print('PyTorch version:', torch.__version__)
sanity_check('CPU', cpu)
sanity_check('GPU', gpu)
print()
run_benchmark('CPU warmup', 3, False, cpu, fwd_only)
run_benchmark('CPU forward', 6, False, cpu, fwd_only)
run_benchmark('CPU for- & backward', 6, True, cpu, fwd_bck)
print()
run_benchmark('GPU warmup', 3, False, gpu, fwd_only)
run_benchmark('GPU forward', 6, False, gpu, fwd_only)
run_benchmark('GPU for- & backward', 6, True, gpu, fwd_bck)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/27673
Differential Revision: D17889288
Pulled By: ezyang
fbshipit-source-id: 9ddffe4dbbfab6180847a8fec32443910f18f0a9
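
To make the "2-liner of regular Tensor functions" concrete, the sketch below composes soft margin loss from ordinary ops using its standard definition, loss(x, y) = mean(log(1 + exp(-y * x))). This is an illustrative assumption, not the exact code added in this PR; the helper name `soft_margin_loss_composite` is hypothetical, and only `F.soft_margin_loss` is the real library API.

```python
import torch
import torch.nn.functional as F

def soft_margin_loss_composite(input, target, reduction='mean'):
    # Hypothetical composite formulation (not the PR's exact code):
    # elementwise log(1 + exp(-target * input)), built from regular Tensor ops
    # instead of a fused TH kernel, so autograd derives the backward pass.
    loss = torch.log1p(torch.exp(-target * input))
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss

x = torch.randn(10, requires_grad=True)
y = torch.randn(10).sign()  # soft margin targets are typically -1 or +1
print(torch.allclose(soft_margin_loss_composite(x, y), F.soft_margin_loss(x, y)))
```

Composing `exp`, `log1p`, and the reduction this way lets the operator reuse ATen's existing vectorized CPU and CUDA kernels, which is consistent with the speedups over the fused TH kernels reported above.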