pytorch
a2419638 - [nll_loss] Avoid unnecessary type casts (#86086)

Commit
2 years ago
[nll_loss] Avoid unnecessary type casts (#86086) follow-up #85395 `AT_DISPATCH_NLL_LOSS_INDEX_TYPES` should not be removed in favor of #59765 and there's a testcase https://github.com/pytorch/pytorch/blob/99ca25e6eb8299f31824bdbaf62f16f8a8db458d/test/test_nn.py#L16832 Besides the dispatcher, I wanted to sanity check `int64_t ignore_index` because `int64_t` can be inappropriate considering that `target` can be `Byte`. However, given that the default value is -100 as in https://github.com/pytorch/pytorch/blob/0a75c42f36c0e50a22c06fa65478df53d7d420c4/aten/src/ATen/native/native_functions.yaml#L9949 it's not easy to add a check while keeping the backward compatibility. Thus I decided to not add a check. cc @lezcano @t-vi Pull Request resolved: https://github.com/pytorch/pytorch/pull/86086 Approved by: https://github.com/lezcano
Author
Committer
Parents
Loading