pytorch
e86476f7 - Huber loss (#50553)

Commit
3 years ago
Huber loss (#50553) Summary: Fixes https://github.com/pytorch/pytorch/issues/48595. ## Background This PR implements HuberLoss, which differs from SmoothL1Loss by a factor of beta. The current implementation does not share logic between the two. Feedback is welcome for the optimal way to minimize code duplication while remaining performant. I've done some early [benchmarking](https://pytorch.org/tutorials/recipes/recipes/benchmark.html#collecting-instruction-counts-with-callgrind) with Huber calling in to the Smooth L1 kernel and scaling afterwards; for the simple test case I used, instruction counts are as follows: ``` Huber loss calls dedicated Huber kernel: 2,795,300 Huber loss calls Smooth L1 kernel and scales afterwards: 4,523,612 ``` With these numbers, instruction counts are ~62% higher when using the pre-existing Smooth L1 kernel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/50553 Test Plan: ``` python test/test_nn.py TestNN.test_HuberLoss python test/test_nn.py TestNN.test_HuberLoss_delta python test/test_nn.py TestNN.test_huber_loss_invalid_delta python test/test_nn.py TestNNDeviceTypeCPU.test_smooth_l1_loss_vs_huber_loss_cpu python test/test_nn.py TestNNDeviceTypeCUDA.test_smooth_l1_loss_vs_huber_loss_cuda python test/test_nn.py TestNNDeviceTypeCPU.test_invalid_reduction_strings_cpu python test/test_nn.py TestNNDeviceTypeCUDA.test_invalid_reduction_strings_cuda python test/test_nn.py TestNN.test_loss_equal_input_target_shape python test/test_nn.py TestNN.test_pointwise_loss_broadcast python test/test_overrides.py python test/test_jit.py TestJitGeneratedFunctional.test_nn_huber_loss python test/test_type_hints.py python test/test_cpp_api_parity.py build/bin/test_api ``` ## Documentation <img width="677" alt="Screen Shot 2021-01-14 at 4 25 08 PM" src="https://user-images.githubusercontent.com/75754324/104651224-5a445980-5685-11eb-884b-14ea517958c2.png"> <img width="677" alt="Screen Shot 2021-01-14 at 4 24 35 PM" src="https://user-images.githubusercontent.com/75754324/104651190-4e589780-5685-11eb-974d-8c63a89c050e.png"> <img width="661" alt="Screen Shot 2021-01-14 at 4 24 45 PM" src="https://user-images.githubusercontent.com/75754324/104651198-50225b00-5685-11eb-958e-136b36f6f8a8.png"> <img width="869" alt="Screen Shot 2021-01-14 at 4 25 27 PM" src="https://user-images.githubusercontent.com/75754324/104651208-53b5e200-5685-11eb-9fe4-5ff433aa13c5.png"> <img width="862" alt="Screen Shot 2021-01-14 at 4 25 48 PM" src="https://user-images.githubusercontent.com/75754324/104651209-53b5e200-5685-11eb-8051-b0cfddcb07d3.png"> Reviewed By: H-Huang Differential Revision: D26734071 Pulled By: jbschlosser fbshipit-source-id: c98c1b5f32a16f7a2a4e04bdce678080eceed5d5
Author
Parents
Loading