MAINT Migrates multilabel_margin_loss from THC to ATen (CUDA) (#60708)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24603
Fixes https://github.com/pytorch/pytorch/issues/24602
<s>The implementation should be exactly the same, so it is strange that the benchmarks show such a significant improvement in this PR.</s>
The benchmarks are now the same.
<details>
<summary>Benchmark script</summary>
```python
from itertools import product
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
torch.manual_seed(0)
MS_PER_SECOND = 1000
def _time():
    torch.cuda.synchronize()
    return time.perf_counter() * MS_PER_SECOND
device = "cuda"
C = 30
n_runs = 100
reductions = ["none", "sum", "mean"]
Ns = [1_000, 10_000, 100_000]
for reduction, N in product(reductions, Ns):
    total_fwd_time = 0
    total_back_time = 0
    grad_out = torch.randn(N, device=device)
    if reduction != "none":
        grad_out = grad_out[0]
    for _ in range(n_runs):
        input = torch.randn(N, C, device=device, requires_grad=True)
        target = torch.randint(0, C, size=input.size(), device=device)
        # forward
        start = _time()
        result = F.multilabel_margin_loss(input, target, reduction=reduction)
        total_fwd_time += _time() - start
    result = F.multilabel_margin_loss(input, target, reduction=reduction)
    for _ in range(n_runs):
        # backward
        start = _time()
        result.backward(grad_out, retain_graph=True)
        total_back_time += _time() - start
    fwd_avg = total_fwd_time / n_runs
    bwd_avg = total_back_time / n_runs
    print(
        f"input size({N}, {C}), reduction: {reduction}, fwd: {fwd_avg:.2f} (ms), back: {bwd_avg:.2f} (ms)"
    )
```
</details>
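For reference, a minimal sketch (not part of this PR) of what the migrated op computes: `multilabel_margin_loss` takes a target of the same shape as the input, holding class indices, where entries after the first `-1` are ignored, so `-1` acts as a terminator for the label list. The numbers below follow the standard documented example.

```python
import torch
import torch.nn.functional as F

# Illustrative example: one sample with 4 classes.
input = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
# Labels 3 and 0 are the positives; -1 terminates the label list,
# so the trailing 1 is ignored.
target = torch.tensor([[3, 0, -1, 1]])
loss = F.multilabel_margin_loss(input, target, reduction="mean")
# Margin terms: for y=3: 0.4 + 0.6; for y=0: 1.1 + 1.3; sum 3.4, divided by C=4.
print(loss.item())  # 0.85
```

This runs on CPU as well; the PR only changes the CUDA kernels from THC to ATen, not the semantics.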
## master
```
input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.41 (ms)
input size(10000, 30), reduction: none, fwd: 1.26 (ms), back: 3.58 (ms)
input size(100000, 30), reduction: none, fwd: 13.15 (ms), back: 34.68 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 3.53 (ms)
input size(100000, 30), reduction: sum, fwd: 13.04 (ms), back: 34.53 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.38 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 3.52 (ms)
input size(100000, 30), reduction: mean, fwd: 13.12 (ms), back: 34.54 (ms)
```
## this PR
```
input size(1000, 30), reduction: none, fwd: 0.14 (ms), back: 0.35 (ms)
input size(10000, 30), reduction: none, fwd: 1.22 (ms), back: 2.98 (ms)
input size(100000, 30), reduction: none, fwd: 12.90 (ms), back: 29.32 (ms)
input size(1000, 30), reduction: sum, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: sum, fwd: 1.16 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: sum, fwd: 13.00 (ms), back: 29.17 (ms)
input size(1000, 30), reduction: mean, fwd: 0.14 (ms), back: 0.32 (ms)
input size(10000, 30), reduction: mean, fwd: 1.17 (ms), back: 2.97 (ms)
input size(100000, 30), reduction: mean, fwd: 13.09 (ms), back: 28.91 (ms)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60708
Reviewed By: saketh-are
Differential Revision: D29856579
Pulled By: ngimel
fbshipit-source-id: b6bbf27a71e5a04f61779f6fef4ed1c98baa2607