pytorch
53c2fc65 - Abate spurious resize warnings in `MultiMarginLoss` on CUDA

Commit
4 years ago
Abate spurious resize warnings in `MultiMarginLoss` on CUDA The current `multi_margin_loss_cuda_out` implementation has a few cases where an output is intended to have shape `{0}` before resizing but is improperly set before this step, triggering `UserWarning`s. This PR changes the initial `at::empty` call to avoid this warning and alters the current resizing logic to match that of the CPU version to avoid the warning when calling `at::sum_out`. The original issue was reported from the user forum by @hadaev8 [here](https://discuss.pytorch.org/t/warning-spam-then-using-multimarginloss/147492). CC @ptrblck Original repro provided by @hadaev8: ``` import torch from torch import nn from torch.nn import functional as F bs = 56 model = nn.Linear(128, 22).cuda() loss = nn.MultiMarginLoss() x = torch.rand((bs, 128)).cuda() targets = torch.randint(22, (bs,)).cuda() out = model(x) print(targets.shape) print(out.shape) loss(out, targets) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/75000 Approved by: https://github.com/ngimel
Author
eqy eqy
Committer
Parents
Loading