ENH Migrate nll_loss2d from THC to ATen (#62826)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24608
Fixes https://github.com/pytorch/pytorch/issues/24607
With the following benchmark, the backward pass runs a little slower. This is strange since the implementation should be exactly the same.
<details>
<summary>Benchmark script</summary>
```python
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
torch.manual_seed(0)
MS_PER_SECOND = 1000
def _time():
torch.cuda.synchronize()
return time.perf_counter() * MS_PER_SECOND
device = "cuda"
C = 3
n_runs = 30
reductions = ["none", "sum", "mean"]
Ns = [128, 256, 512]
Hs = [128, 256, 512]
for reduction, N, H in product(reductions, Ns, Hs):
total_fwd_time = 0
total_back_time = 0
if reduction == "none":
grad_out = torch.randn(N, H, H, device=device)
else:
grad_out = torch.randn(1)[0]
for _ in range(n_runs):
input = torch.randn(N, C, H, H, device=device, requires_grad=True)
target = torch.rand(N, H, H, device=device).mul(3).floor().long()
# forward
start = _time()
result = F.nll_loss(input, target, reduction=reduction)
total_fwd_time += _time() - start
result = F.nll_loss(input, target, reduction=reduction)
for _ in range(n_runs):
# backward
start = _time()
result.backward(grad_out, retain_graph=True)
total_back_time += _time() - start
fwd_avg = total_fwd_time / n_runs
bwd_avg = total_back_time / n_runs
print(
f"input size({N}, {C}, {H}, {H}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
)
```
</details>
<details>
<summary>master results</summary>
```
input size(128, 3, 128, 128), reduction: none, fwd: 0.34 (ms), back: 0.57 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.56 (ms), back: 3.85 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.54 (ms), back: 16.62 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.26 (ms), back: 1.78 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.22 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.38 (ms), back: 33.29 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.41 (ms), back: 4.05 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.32 (ms), back: 16.46 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.20 (ms), back: 66.68 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.08 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 5.66 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 2.86 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 11.23 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.21 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.73 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 2.86 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.13 (ms), back: 0.39 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.45 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.54 (ms), back: 5.65 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.22 (ms), back: 0.74 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 2.87 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 11.23 (ms)
```
</details>
<details>
<summary>PR results</summary>
```
input size(128, 3, 128, 128), reduction: none, fwd: 0.33 (ms), back: 0.59 (ms)
input size(128, 3, 256, 256), reduction: none, fwd: 2.51 (ms), back: 3.92 (ms)
input size(128, 3, 512, 512), reduction: none, fwd: 14.52 (ms), back: 17.05 (ms)
input size(256, 3, 128, 128), reduction: none, fwd: 1.23 (ms), back: 1.85 (ms)
input size(256, 3, 256, 256), reduction: none, fwd: 7.07 (ms), back: 8.45 (ms)
input size(256, 3, 512, 512), reduction: none, fwd: 29.39 (ms), back: 34.21 (ms)
input size(512, 3, 128, 128), reduction: none, fwd: 3.40 (ms), back: 4.18 (ms)
input size(512, 3, 256, 256), reduction: none, fwd: 14.33 (ms), back: 16.90 (ms)
input size(512, 3, 512, 512), reduction: none, fwd: 59.04 (ms), back: 68.36 (ms)
input size(128, 3, 128, 128), reduction: sum, fwd: 0.07 (ms), back: 0.25 (ms)
input size(128, 3, 256, 256), reduction: sum, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: sum, fwd: 0.82 (ms), back: 3.33 (ms)
input size(256, 3, 128, 128), reduction: sum, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: sum, fwd: 0.42 (ms), back: 1.70 (ms)
input size(256, 3, 512, 512), reduction: sum, fwd: 1.53 (ms), back: 6.58 (ms)
input size(512, 3, 128, 128), reduction: sum, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: sum, fwd: 0.78 (ms), back: 3.34 (ms)
input size(512, 3, 512, 512), reduction: sum, fwd: 2.98 (ms), back: 13.07 (ms)
input size(128, 3, 128, 128), reduction: mean, fwd: 0.07 (ms), back: 0.26 (ms)
input size(128, 3, 256, 256), reduction: mean, fwd: 0.21 (ms), back: 0.86 (ms)
input size(128, 3, 512, 512), reduction: mean, fwd: 0.82 (ms), back: 3.34 (ms)
input size(256, 3, 128, 128), reduction: mean, fwd: 0.12 (ms), back: 0.46 (ms)
input size(256, 3, 256, 256), reduction: mean, fwd: 0.42 (ms), back: 1.72 (ms)
input size(256, 3, 512, 512), reduction: mean, fwd: 1.53 (ms), back: 6.60 (ms)
input size(512, 3, 128, 128), reduction: mean, fwd: 0.21 (ms), back: 0.87 (ms)
input size(512, 3, 256, 256), reduction: mean, fwd: 0.78 (ms), back: 3.33 (ms)
input size(512, 3, 512, 512), reduction: mean, fwd: 2.98 (ms), back: 13.07 (ms)
```
</details>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62826
Reviewed By: bdhirsh
Differential Revision: D30282279
Pulled By: ngimel
fbshipit-source-id: 4aa0ff3f8af0632957417931d332ec486a12b52d