pytorch
368e364c - [MPS] Fix gradient issues with NLL and Smooth_L1 loss ops (#94226)

Commit
1 year ago
[MPS] Fix gradient issues with NLL and Smooth_L1 loss ops (#94226) - Fix correctness issues with nll_loss_backward(), smooth_l1_loss_backward() and cross_entropy_backward() by taking grad_output into account when computing those loss ops - Add numel()==0 check to prevent crashes - Clean up and formatting Pull Request resolved: https://github.com/pytorch/pytorch/pull/94226 Approved by: https://github.com/kulinseth
Author
Committer
Parents
Loading