Migrate soft_margin_loss from the TH to Aten (CUDA+CPU) (#28135)
Summary:
Fix: https://github.com/pytorch/pytorch/issues/24631, https://github.com/pytorch/pytorch/issues/24632, https://github.com/pytorch/pytorch/issues/24764, https://github.com/pytorch/pytorch/issues/24765
Port of TH SoftMarginCriterion to ATen using un-fused tensor operators but with custom backward code. This is a follow-up/fixc of reverted PR https://github.com/pytorch/pytorch/issues/27673.
Benchmark results:
CPU became faster, GPU slower. To reach previous TH perf probably manual fusion is necessary.
### WITH patch
```
CPU warmup 1000 took 7.997200009413064e-05
CPU warmup 10000 took 0.0008116499957395718
CPU warmup 100000 took 0.0012691459996858612
CPU warmup TOTAL time 0.0021982479956932366
CPU forward 1000 took 7.320100849028677e-05
CPU forward 10000 took 0.00015837099635973573
CPU forward 100000 took 0.0010471990099176764
CPU forward 1000000 took 0.01238470000680536
CPU forward 10000000 took 0.12747182900784537
CPU forward 100000000 took 1.2076255190040683
CPU forward TOTAL time 1.3488940890092636
CPU for- & backward 1000 took 0.00032587299938313663
CPU for- & backward 10000 took 0.0006926299975020811
CPU for- & backward 100000 took 0.002146183993318118
CPU for- & backward 1000000 took 0.019158899012836628
CPU for- & backward 10000000 took 0.2957490350090666
CPU for- & backward 100000000 took 1.7630806300003314
CPU for- & backward TOTAL time 2.081367089995183
GPU warmup 1000 took 0.0004558280052151531
GPU warmup 10000 took 0.0002567449992056936
GPU warmup 100000 took 0.0001593509950907901
GPU warmup TOTAL time 0.0009442300070077181
GPU forward 1000 took 0.00015061900194268674
GPU forward 10000 took 0.00015258099301718175
GPU forward 100000 took 0.00015409699699375778
GPU forward 1000000 took 0.0008183339959941804
GPU forward 10000000 took 0.004424853003001772
GPU forward 100000000 took 0.04356115800328553
GPU forward TOTAL time 0.04938192600093316
GPU for- & backward 1000 took 0.0008062430133577436
GPU for- & backward 10000 took 0.0006074949924368411
GPU for- & backward 100000 took 0.0007091690058587119
GPU for- & backward 1000000 took 0.001022183001623489
GPU for- & backward 10000000 took 0.009945805999450386
GPU for- & backward 100000000 took 0.0944173600000795
GPU for- & backward TOTAL time 0.28060428200114984
```
### WITHOUT patch
```
CPU warmup 1000 took 6.394000956788659e-05
CPU warmup 10000 took 0.00038220599526539445
CPU warmup 100000 took 0.0034939230099553242
CPU warmup TOTAL time 0.003981974994530901
CPU forward 1000 took 4.7855006414465606e-05
CPU forward 10000 took 0.000347569992300123
CPU forward 100000 took 0.003367935001733713
CPU forward 1000000 took 0.03605044000141788
CPU forward 10000000 took 0.35935167300340254
CPU forward 100000000 took 3.630371332008508
CPU forward TOTAL time 4.029640004009707
CPU for- & backward 1000 took 0.00028494100843090564
CPU for- & backward 10000 took 0.0006738200027029961
CPU for- & backward 100000 took 0.0051178760040784255
CPU for- & backward 1000000 took 0.04925115800870117
CPU for- & backward 10000000 took 0.7172313440096332
CPU for- & backward 100000000 took 5.441953932997421
CPU for- & backward TOTAL time 6.21466830400459
GPU warmup 1000 took 0.001803738996386528
GPU warmup 10000 took 0.00041877900366671383
GPU warmup 100000 took 0.0003870719956466928
GPU warmup TOTAL time 0.0026561370032140985
GPU forward 1000 took 0.00037833399255760014
GPU forward 10000 took 0.00038825398951303214
GPU forward 100000 took 0.0003841099969577044
GPU forward 1000000 took 0.0007090550061548129
GPU forward 10000000 took 0.0016171559982467443
GPU forward 100000000 took 0.013463679002597928
GPU forward TOTAL time 0.017010531009873375
GPU for- & backward 1000 took 0.0007374050037469715
GPU for- & backward 10000 took 0.0006343529967125505
GPU for- & backward 100000 took 0.0006375070079229772
GPU for- & backward 1000000 took 0.0007550300069851801
GPU for- & backward 10000000 took 0.002672752001672052
GPU for- & backward 100000000 took 0.023170708998804912
GPU for- & backward TOTAL time 0.20251446698966902
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28135
Differential Revision: D18001447
Pulled By: VitalyFedyunin
fbshipit-source-id: ad90dc1cca42dcaf3ea9e17e4f8fd79cee0a293e