pytorch
2b0eddb0 - [Static Runtime] Implement prim::isinstance and prim::TypeCheck (#61783)

Commit
3 years ago
[Static Runtime] Implement prim::isinstance and prim::TypeCheck (#61783) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61783 Implement two new prim operators for static runtime: `isinstance` and `TypeCheck`. `isinstance` is very straightforward, but there were a few wrinkles with implementing `TypeCheck`: 1. There is no way to directly generate `TypeCheck` nodes from TorchScript, they are generated by the JIT at runtime. This makes testing a little difficult. I had to make some modifications to `testStaticRuntime` to allow for the use of IR and TorchScript tests. 2. The behavior of `prim::TypeCheck` as implemented here does not match up 1:1 with the version implemented in the interpreter! This is because grad mode is disabled in static runtime. Here's an example. IR is the same as the one included in this test, but with `requires_grad == 1` ``` graph(%a.1 : Tensor, %b.1 : Tensor): %t0 : Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu, requires_grad=1), Float(3, 3, strides=[3, 1])]](%a.1, %b.1) return (%t0, %t1, %type_matched) ``` And in the test setup: ``` auto a = at::zeros({2, 2}, at::kFloat); a.to(at::kCPU); a.set_requires_grad(true); auto b = at::ones({3, 3}, at::kFloat); std::vector<IValue> args_correct = {a, b}; // prim::TypeCheck should be true with args_correct, // but we get false when using static runtime! ``` Reviewed By: hlu1 Differential Revision: D29743862 fbshipit-source-id: db1788f0f5de42bab42602e8cc24eee04cbcc280
Author
Mike Iovine
Parents
Loading