Migrates nll_loss_backward from TH to ATen (CUDA) (#60299)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24609
ATen umbrella issue: https://github.com/pytorch/pytorch/issues/24507
Related to https://github.com/pytorch/pytorch/issues/59765
There are no meaningful performance differences when running the following benchmark; every configuration on this PR is within about 1.5% of master:
<details>
<summary>Benchmark script</summary>
```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)


def _time():
    # Wait for pending CUDA kernels to finish before reading the clock.
    torch.cuda.synchronize()
    MS_PER_SECOND = 1000
    return time.perf_counter() * MS_PER_SECOND


device = "cuda"
C = 30
softmax = nn.LogSoftmax(dim=1)
n_runs = 250

for reduction in ["none", "mean", "sum"]:
    for N in [100_000, 500_000, 1_000_000]:
        elapsed = 0
        for i in range(n_runs):
            data = torch.randn(N, C, device=device, requires_grad=True)
            target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
            loss = nn.NLLLoss(reduction=reduction)
            input = softmax(data)
            result = loss(input, target)
            # The upstream gradient must match the shape of the loss output.
            if reduction == "none":
                gradient = torch.randn(N, device=device)
            else:
                gradient = torch.randn(1, device=device).squeeze()
            # Time only the backward pass.
            t1 = _time()
            result.backward(gradient)
            t2 = _time()
            elapsed = elapsed + (t2 - t1)
        elapsed_avg = elapsed / n_runs
        print(
            f"input size({N}, {C}), reduction: {reduction} "
            f"elapsed time is {elapsed_avg:.2f} (ms)"
        )
    print()
```
</details>
## master
```
input size(100000, 30), reduction: none elapsed time is 0.19 (ms)
input size(500000, 30), reduction: none elapsed time is 0.83 (ms)
input size(1000000, 30), reduction: none elapsed time is 1.66 (ms)
input size(100000, 30), reduction: mean elapsed time is 1.50 (ms)
input size(500000, 30), reduction: mean elapsed time is 7.19 (ms)
input size(1000000, 30), reduction: mean elapsed time is 14.35 (ms)
input size(100000, 30), reduction: sum elapsed time is 1.49 (ms)
input size(500000, 30), reduction: sum elapsed time is 7.17 (ms)
input size(1000000, 30), reduction: sum elapsed time is 14.21 (ms)
```
## this PR
```
input size(100000, 30), reduction: none elapsed time is 0.19 (ms)
input size(500000, 30), reduction: none elapsed time is 0.83 (ms)
input size(1000000, 30), reduction: none elapsed time is 1.66 (ms)
input size(100000, 30), reduction: mean elapsed time is 1.48 (ms)
input size(500000, 30), reduction: mean elapsed time is 7.16 (ms)
input size(1000000, 30), reduction: mean elapsed time is 14.29 (ms)
input size(100000, 30), reduction: sum elapsed time is 1.49 (ms)
input size(500000, 30), reduction: sum elapsed time is 7.15 (ms)
input size(1000000, 30), reduction: sum elapsed time is 14.18 (ms)
```
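To make the comparison easier to check, here is a minimal sketch that computes the relative change per configuration from the two outputs above. The timings are copied verbatim from this message; the `relative_change` helper and the dictionary layout are illustrative only and are not part of the PR.

```python
# Average backward times in ms, copied from the "master" and "this PR"
# outputs above, keyed by (reduction, N).
master = {
    ("none", 100_000): 0.19, ("none", 500_000): 0.83, ("none", 1_000_000): 1.66,
    ("mean", 100_000): 1.50, ("mean", 500_000): 7.19, ("mean", 1_000_000): 14.35,
    ("sum", 100_000): 1.49, ("sum", 500_000): 7.17, ("sum", 1_000_000): 14.21,
}
this_pr = {
    ("none", 100_000): 0.19, ("none", 500_000): 0.83, ("none", 1_000_000): 1.66,
    ("mean", 100_000): 1.48, ("mean", 500_000): 7.16, ("mean", 1_000_000): 14.29,
    ("sum", 100_000): 1.49, ("sum", 500_000): 7.15, ("sum", 1_000_000): 14.18,
}


def relative_change(before, after):
    """Percent change from `before` to `after` (negative means faster)."""
    return (after - before) / before * 100.0


# Hypothetical helper output: one percentage per benchmark configuration.
changes = {key: relative_change(master[key], this_pr[key]) for key in master}

for (reduction, n), pct in changes.items():
    print(f"reduction={reduction:<4} N={n:>9,}: {pct:+.2f}%")

worst = max(abs(pct) for pct in changes.values())
print(f"largest absolute change: {worst:.2f}%")
```

With timings reported to two decimal places, differences of this size are likely within run-to-run noise.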
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60299
Reviewed By: albanD
Differential Revision: D29287613
Pulled By: ngimel
fbshipit-source-id: 21e15f2c518087e9fb797a379e1e0a3508c98509