pytorch
a4b02a15 - Support registering op returning symint in python (#95240)

Commit
1 year ago
Support registering op returning symint in python (#95240) Running an operator registered in python returning a symint will result in the following error: ``` RuntimeError: Unable to cast Python instance of type <class 'torch.SymInt'> to C++ type 'long' ``` The interaction of 2 things make the issue being triggered: - We use boxed kernel here. For boxed kernel, we need convert py::object to IValue in torch/csrc/autograd/python_variable.cpp pushPyOutToStack . - In the schema parsing code in torch/csrc/jit/frontend/schema_type_parser.cpp SchemaTypeParser::parseFakeAndRealType , if a SymInt is found, we register a Int type instead (not sure why we do this), and register SymInt as the real type. The result is we would convert an SymInt to int in pushPyOutToStack and cause the issue. The fix is to use real type when we convert py::object to IValue. BTW, registering the same op using C++ API does not trigger the issue. ``` TORCH_LIBRARY(clib, m) { m.def("sqsum(SymInt a, SymInt b) -> SymInt", [](SymInt a, SymInt b) -> SymInt { return a * a + b * b; }); } ``` The reason is, the kernel registered in C++ is unboxed kernel and it does not trigger the code path above that converts an py::object to IValue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/95240 Approved by: https://github.com/larryliu0820, https://github.com/ezyang
Author
Committer
Parents
Loading