pytorch
c7ed89cf - Migrate `nll_loss2d` from TH to ATen (CPU) (#28304)

Commit
5 years ago
Migrate `nll_loss2d` from TH to ATen (CPU) (#28304) Summary: Added check for indicies in Reduction::None case. ### Benchmark results Note: Due to the size of the input tensors this time the random number generation is responsible for a significant portion of the total time. It is better to look at the individual net time-outputs (which do not include the input preparation). Script used for benchmark.: [nnl_loss2d_benchmark.py](https://gist.github.com/andreaskoepf/5864aa91e243317cb282c1e7fe576e1b) #### WITH PR applied ``` using reduction: none CPU forward 1000 took 7.916500908322632e-05 CPU forward 10000 took 0.0002642290201038122 CPU forward 100000 took 0.003828087996225804 CPU forward 1000000 took 0.037140720000024885 CPU forward 10000000 took 0.33387596398824826 CPU forward TOTAL time 7.218988707987592 using reduction: mean CPU forward 1000 took 9.165197843685746e-05 CPU forward 10000 took 0.0005258890159893781 CPU forward 100000 took 0.0050761590246111155 CPU forward 1000000 took 0.047345594997750595 CPU forward 10000000 took 0.4790863030066248 CPU forward TOTAL time 7.9106070210109465 CPU for- & backward 1000 took 0.0005489500181283802 CPU for- & backward 10000 took 0.0015284279943443835 CPU for- & backward 100000 took 0.015138130984269083 CPU for- & backward 1000000 took 0.15741890601930209 CPU for- & backward 10000000 took 1.6703072849777527 CPU for- & backward TOTAL time 9.555764263990568 using reduction: sum CPU forward 1000 took 8.789298590272665e-05 CPU forward 10000 took 0.000514078012201935 CPU forward 100000 took 0.005135576997417957 CPU forward 1000000 took 0.04715992201818153 CPU forward 10000000 took 0.4821214270195924 CPU forward TOTAL time 7.9119505700073205 CPU for- & backward 1000 took 0.00047759301378391683 CPU for- & backward 10000 took 0.0015945070190355182 CPU for- & backward 100000 took 0.018208994006272405 CPU for- & backward 1000000 took 0.15904426100314595 CPU for- & backward 10000000 took 1.5679037219961174 CPU for- & backward TOTAL time 9.495157692988869 ``` #### WITHOUT original TH impl ``` using reduction: none CPU forward 1000 took 0.0003981560003012419 CPU forward 10000 took 0.0035912430030293763 CPU forward 100000 took 0.035353766987100244 CPU forward 1000000 took 0.3428319719969295 CPU forward 10000000 took 3.364342701010173 CPU forward TOTAL time 11.166179805004504 using reduction: mean CPU forward 1000 took 8.63690220285207e-05 CPU forward 10000 took 0.0004704220045823604 CPU forward 100000 took 0.0045734510058537126 CPU forward 1000000 took 0.046232511987909675 CPU forward 10000000 took 0.4191019559802953 CPU forward TOTAL time 7.846049971994944 CPU for- & backward 1000 took 0.0005974550149403512 CPU for- & backward 10000 took 0.0014057719963602722 CPU for- & backward 100000 took 0.013776941981632262 CPU for- & backward 1000000 took 0.13876214998890646 CPU for- & backward 10000000 took 1.3666698939923663 CPU for- & backward TOTAL time 9.10526105100871 using reduction: sum CPU forward 1000 took 7.598899537697434e-05 CPU forward 10000 took 0.00046885499614290893 CPU forward 100000 took 0.0044489419960882515 CPU forward 1000000 took 0.04495517900795676 CPU forward 10000000 took 0.418376043002354 CPU forward TOTAL time 7.789334400993539 CPU for- & backward 1000 took 0.0004464260127861053 CPU for- & backward 10000 took 0.0017732900159899145 CPU for- & backward 100000 took 0.01626713399309665 CPU for- & backward 1000000 took 0.11790941300569102 CPU for- & backward 10000000 took 1.4346664609911386 CPU for- & backward TOTAL time 9.294745502003934 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/28304 Differential Revision: D18350157 Pulled By: ezyang fbshipit-source-id: e9437debe51386a483f4265193c475cdc90b28e4
Author
Parents
Loading