pytorch
1f0223d6 - Fix bug in gaussian_nll_loss (#56469)

Commit
3 years ago
Fix bug in gaussian_nll_loss (#56469) Summary: Fixes https://github.com/pytorch/pytorch/issues/53964. cc albanD almson ## Major changes: - Overhauled the actual loss calculation so that the shapes are now correct (in functional.py) - added the missing doc in nn.functional.rst ## Minor changes (in functional.py): - I removed the previous check on whether input and target were the same shape. This is to allow for broadcasting, say when you have 10 predictions that all have the same target. - I added some comments to explain each shape check in detail. Let me know if these should be shortened/cut. Screenshots of updated docs attached. Let me know what you think, thanks! ## Edit: Description of change of behaviour (affecting BC): The backwards-compatibility is only affected for the `reduction='none'` mode. This was the source of the bug. For tensors with size (N, D), the old returned loss had size (N), as incorrect summation was happening. It will now have size (N, D) as expected. ### Example Define input tensors, all with size (2, 3). `input = torch.tensor([[0., 1., 3.], [2., 4., 0.]], requires_grad=True)` `target = torch.tensor([[1., 4., 2.], [-1., 2., 3.]])` `var = 2*torch.ones(size=(2, 3), requires_grad=True)` Initialise loss with reduction mode 'none'. We expect the returned loss to have the same size as the input tensors, (2, 3). `loss = torch.nn.GaussianNLLLoss(reduction='none')` Old behaviour: `print(loss(input, target, var)) ` `# Gives tensor([3.7897, 6.5397], grad_fn=<MulBackward0>. This has size (2).` New behaviour: `print(loss(input, target, var)) ` `# Gives tensor([[0.5966, 2.5966, 0.5966], [2.5966, 1.3466, 2.5966]], grad_fn=<MulBackward0>)` `# This has the expected size, (2, 3).` To recover the old behaviour, sum along all dimensions except for the 0th: `print(loss(input, target, var).sum(dim=1))` `# Gives tensor([3.7897, 6.5397], grad_fn=<SumBackward1>.` ![doc1](https://user-images.githubusercontent.com/26558092/115391089-f7f47b00-a1d6-11eb-8726-e4da9057aee0.png) ![doc2](https://user-images.githubusercontent.com/26558092/115391094-f925a800-a1d6-11eb-954b-afd187f42bc7.png) Pull Request resolved: https://github.com/pytorch/pytorch/pull/56469 Reviewed By: jbschlosser, agolynski Differential Revision: D27894170 Pulled By: albanD fbshipit-source-id: 197890189c97c22109491c47f469336b5b03a23f
Author
M.L. Croci
Parents
Loading