pytorch
182ee879 - symintify nll loss fns (#86915) (#87095)

Commit
2 years ago
symintify nll loss fns (#86915) (#87095) This reverts commit bbd7b38d5580c44ffb4404d431e07bc2316e59d5. Reland https://github.com/pytorch/pytorch/pull/86915 with a fix for python arg parser handing for SymInt and SymIntList. This was uncovered because we are calling directly into python bindings code through test_autocast.py (`torch._C._nn.nll_loss`) without providing a value for the optional symint arg (`ignore_index`). The arg parser constructs the SymInt and SymIntList using the recorded "default_int" or "default_int_list" (schema string parsing) in case a value is not received for an optional argument. Since we weren't handling the symint case properly, the default_int just had a garbage value which was later being used to construct SymInt. Follow up issue for other unhandled parameter types: https://github.com/pytorch/pytorch/issues/87283 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87095 Approved by: https://github.com/ezyang, https://github.com/albanD
Author
Committer
Parents
Loading