pytorch
f6eb8117 - Add RefineTypes JIT pass for Tuple (#76919)

Commit
2 years ago
Add RefineTypes JIT pass for Tuple (#76919) Consider the following JIT graph, where the type of `%a` and `%b` are out of sync with tuple `%c`. Before: ``` graph(%a : Float(123), %b : Float(4, 5, 6)): c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) return (%c) ``` After: ``` graph(%a : Float(123), %b : Float(4, 5, 6)): c : (Float(123), Float(4, 5, 6)) = prim::TupleConstruct(%a, %b) return (%c) ``` This PR adds a pass `RefineTypes(...)` to update all such instances with the correct type. This is also available via Python by using `torch._C._jit_pass_refine_types(...)`. A unit test has been added for unnamed tuples, but no test exists for `NamedTuple` (though it was tested manually) since it isn't supported by the parser: ``` RuntimeError: unknown type specifier: graph(%a : Float(123), %b : Float(4, 5, 6)): %c : NamedTuple(Tensor : Tuple, Tensor : Tuple) = prim::TupleConstruct(%a, %b) ~~~~~~~~~~ <--- HERE return (%c) ``` cc: @ke1337 @antoniojkim @wconstab @eellison Pull Request resolved: https://github.com/pytorch/pytorch/pull/76919 Approved by: https://github.com/eellison
Author
Committer
Parents
Loading