pytorch
9730d91a - MAINT Migrates multilabel_margin_loss from THC to ATen (CUDA) (#60708)

MAINT Migrates multilabel_margin_loss from THC to ATen (CUDA) (#60708)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/24603
Fixes https://github.com/pytorch/pytorch/issues/24602

<s>The implementation should be exactly the same, so it is strange that the benchmarks show such a significant improvement in this PR.</s> The benchmarks are now the same.

<details>
<summary>Benchmark script</summary>

```python
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

torch.manual_seed(0)

MS_PER_SECOND = 1000


def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 30
n_runs = 100
reductions = ["none", "sum", "mean"]
Ns = [1_000, 10_000, 100_000]

for reduction, N in product(reductions, Ns):
    total_fwd_time = 0
    total_back_time = 0
    grad_out = torch.randn(N, device=device)
    if reduction != "none":
        grad_out = grad_out[0]

    for _ in range(n_runs):
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.randint(0, C, size=input.size(), device=device)

        # forward
        start = _time()
        result = F.multilabel_margin_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start

    result = F.multilabel_margin_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start

    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )
```

</details>

## master

```
input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.41 (ms)
input size(10000, 30), reduction: none, fwd: 1.26 (ms), back: 3.58 (ms)
input size(100000, 30), reduction: none, fwd: 13.15 (ms), back: 34.68 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 3.53 (ms)
input size(100000, 30), reduction: sum, fwd: 13.04 (ms), back: 34.53 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 3.52 (ms)
input size(100000, 30), reduction: mean, fwd: 13.12 (ms), back: 34.54 (ms)
```

## this PR

```
input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.35 (ms)
input size(10000, 30), reduction: none, fwd: 1.22 (ms), back: 2.98 (ms)
input size(100000, 30), reduction: none, fwd: 12.90 (ms), back: 29.32 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: sum, fwd: 13.00 (ms), back: 29.17 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: mean, fwd: 13.09 (ms), back: 28.91 (ms)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60708

Reviewed By: saketh-are

Differential Revision: D29856579

Pulled By: ngimel

fbshipit-source-id: b6bbf27a71e5a04f61779f6fef4ed1c98baa2607
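For context on what the benchmark exercises (this sketch is not part of the original commit message): `F.multilabel_margin_loss` expects a LongTensor target of the same shape as the input, holding class indices, where a `-1` entry terminates the label list for that sample. A minimal usage sketch, assuming a CUDA-capable PyTorch build is available:

```python
import torch
import torch.nn.functional as F

# Hypothetical example, not taken from the PR: 2 samples, C = 4 classes.
input = torch.randn(2, 4, device="cuda", requires_grad=True)

# Target has the same shape as input; -1 ends the label list per row,
# so row 0 has labels {3, 0} (the trailing 1 is ignored) and row 1 has {1, 2, 3}.
target = torch.tensor([[3, 0, -1, 1],
                       [1, 2, 3, -1]], device="cuda")

loss = F.multilabel_margin_loss(input, target, reduction="mean")
loss.backward()  # runs the CUDA backward kernel that this PR ports to ATen
print(loss.item())
```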