Migrates nll_loss_forward from TH to Aten (CUDA) (#60097)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24610
Aten Umbrella issue https://github.com/pytorch/pytorch/issues/24507
Related to https://github.com/pytorch/pytorch/issues/59765
The performance does not change between this PR and master with the following benchmark script:
<details>
<summary>Benchmark script</summary>
```python
import torch
import torch.nn as nn
import time
torch.manual_seed(0)
def _time():
torch.cuda.synchronize()
MS_PER_SECOND = 1000
return time.perf_counter() * MS_PER_SECOND
device = "cuda"
C = 30
softmax = nn.LogSoftmax(dim=1)
n_runs = 250
for reduction in ["none", "mean", "sum"]:
for N in [100_000, 500_000, 1_000_000]:
fwd_t = 0
bwd_t = 0
data = torch.randn(N, C, device=device)
target = torch.empty(N, dtype=torch.long, device=device).random_(0, C)
loss = nn.NLLLoss(reduction=reduction)
input = softmax(data)
for i in range(n_runs):
t1 = _time()
result = loss(input, target)
t2 = _time()
fwd_t = fwd_t + (t2 - t1)
fwd_avg = fwd_t / n_runs
print(
f"input size({N}, {C}), reduction: {reduction} "
f"forward time is {fwd_avg:.2f} (ms)"
)
print()
```
</details>
## master
```
input size(100000, 30), reduction: none forward time is 0.02 (ms)
input size(500000, 30), reduction: none forward time is 0.08 (ms)
input size(1000000, 30), reduction: none forward time is 0.15 (ms)
input size(100000, 30), reduction: mean forward time is 1.81 (ms)
input size(500000, 30), reduction: mean forward time is 8.24 (ms)
input size(1000000, 30), reduction: mean forward time is 16.46 (ms)
input size(100000, 30), reduction: sum forward time is 1.66 (ms)
input size(500000, 30), reduction: sum forward time is 8.24 (ms)
input size(1000000, 30), reduction: sum forward time is 16.46 (ms)
```
## this PR
```
input size(100000, 30), reduction: none forward time is 0.02 (ms)
input size(500000, 30), reduction: none forward time is 0.08 (ms)
input size(1000000, 30), reduction: none forward time is 0.15 (ms)
input size(100000, 30), reduction: mean forward time is 1.80 (ms)
input size(500000, 30), reduction: mean forward time is 8.24 (ms)
input size(1000000, 30), reduction: mean forward time is 16.46 (ms)
input size(100000, 30), reduction: sum forward time is 1.66 (ms)
input size(500000, 30), reduction: sum forward time is 8.24 (ms)
input size(1000000, 30), reduction: sum forward time is 16.46 (ms)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60097
Reviewed By: mrshenli
Differential Revision: D29303099
Pulled By: ngimel
fbshipit-source-id: fc0d636543a79ea81158d286dcfb84043bec079a