07b00fc3 - ENH Migrate nll_loss2d from THC to ATen (#62826)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/24608
Fixes https://github.com/pytorch/pytorch/issues/24607

With the benchmark below, the backward pass runs a little slower after the migration. This is surprising, since the implementation should be exactly the same.

<details>
<summary>Benchmark script</summary>

```python
from itertools import product
import time

import torch
import torch.nn.functional as F

torch.manual_seed(0)

MS_PER_SECOND = 1000


def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 3
n_runs = 30
reductions = ["none", "sum", "mean"]
Ns = [128, 256, 512]
Hs = [128, 256, 512]

for reduction, N, H in product(reductions, Ns, Hs):
    total_fwd_time = 0
    total_back_time = 0
    if reduction == "none":
        grad_out = torch.randn(N, H, H, device=device)
    else:
        grad_out = torch.randn(1)[0]

    for _ in range(n_runs):
        input = torch.randn(N, C, H, H, device=device, requires_grad=True)
        target = torch.rand(N, H, H, device=device).mul(3).floor().long()

        # forward
        start = _time()
        result = F.nll_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start

    result = F.nll_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start

    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}, {H}, {H}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )
```

</details>
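For context, the benchmark drives `F.nll_loss` with 4-D `(N, C, H, W)` inputs and `(N, H, W)` class-index targets, which is the case that dispatches to the `nll_loss2d` kernels migrated in this PR. A minimal usage sketch (not part of the PR itself):

```python
import torch
import torch.nn.functional as F

N, C, H, W = 4, 3, 8, 8
# Per-pixel log-probabilities and class-index targets, as in
# semantic segmentation.
x = torch.randn(N, C, H, W, device="cuda", requires_grad=True)
log_probs = F.log_softmax(x, dim=1)
target = torch.randint(0, C, (N, H, W), device="cuda")

loss = F.nll_loss(log_probs, target, reduction="mean")  # 4-D input routes to nll_loss2d
loss.backward()                                         # exercises nll_loss2d_backward
```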
<details>
<summary>master results</summary>

```
input size(128, 3, 128, 128), reduction: none, fwd: 0.34 (ms), back: 0.57 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.56 (ms), back: 3.85 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.54 (ms), back: 16.62 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.26 (ms), back: 1.78 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.22 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.38 (ms), back: 33.29 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.41 (ms), back: 4.05 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.32 (ms), back: 16.46 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.20 (ms), back: 66.68 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.08 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 5.66 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 2.86 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 11.23 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.13 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.54 (ms), back: 5.65 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.22 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 2.87 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 11.23 (ms)
```

</details>

<details>
<summary>PR results</summary>

```
input size(128, 3, 128, 128), reduction: none, fwd: 0.33 (ms), back: 0.59 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.51 (ms), back: 3.92 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.52 (ms), back: 17.05 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.23 (ms), back: 1.85 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.45 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.39 (ms), back: 34.21 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.40 (ms), back: 4.18 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.33 (ms), back: 16.90 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.04 (ms), back: 68.36 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.07 (ms), back: 0.25 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 3.33 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.70 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 6.58 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 3.34 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 13.07 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.26 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 3.34 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.72 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.53 (ms), back: 6.60 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 3.33 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 13.07 (ms)
```

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62826

Reviewed By: bdhirsh

Differential Revision: D30282279

Pulled By: ngimel

fbshipit-source-id: 4aa0ff3f8af0632957417931d332ec486a12b52d
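One way to cross-check the small backward regression seen above is to re-time the kernels with CUDA events instead of `time.perf_counter()`, which removes host-side clock jitter from the measurement. A minimal sketch of such a harness (`time_cuda` is a hypothetical helper, not part of this PR):

```python
import torch
import torch.nn.functional as F

def time_cuda(fn, n_runs=30):
    # CUDA events record timestamps on the GPU stream; elapsed_time()
    # returns device-side elapsed milliseconds between the two events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_runs):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_runs  # average ms per run

input = torch.randn(128, 3, 128, 128, device="cuda", requires_grad=True)
target = torch.randint(0, 3, (128, 128, 128), device="cuda")
print("fwd (ms):", time_cuda(lambda: F.nll_loss(input, target, reduction="mean")))
```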