pytorch
7acf0c6d - [PyTorch Edge][type] Add type support for NamedTuple custom class (export) (#62612)

Commit
3 years ago
[PyTorch Edge][type] Add type support for NamedTuple custom class (export) (#62612) Summary: Add type support for namedtule custom class. For the namedtuple type, it will deserailize to the following format in string ``` "qualified_named[ NamedTuple, [ [filed_name_1, field_type_1], [filed_name_2, field_type_2] ] ]" ``` If it's nested, it will be ``` "__torch__.A[ NamedTuple, [ [field_name_a, __torch__.B [ NamedTuple, [ [field_name_b, __torch__.C [ NamedTuple, [ [field_name_c_1, Tensor], [field_name_c_2, Tuple[Tensor, Tensor]], ] ] ] ] ] ] ] ] " ``` The nametuple type includes both `collection` and `typing`. ``` from typing import NamedTuple from collections import namedtuple ``` It will be a forward incompatible change. However this type is never supported and exported before and we don't have a proper way to backport it. The optimum solution to ship this change is probably 1. Update the change for import without the change to export. So the runtime can read the new format, but no new format will be exported. 2. Update the change to export the new type. So runtime can export new format. For the following example: ``` class Foo(NamedTuple): id: torch.Tensor class Bar(torch.nn.Module): def __init__(self): super(Bar, self).__init__() self.foo = Foo(torch.tensor(1)) def forward(self, a: torch.Tensor): self.foo = Foo(a) return self.foo ``` The new bytecode.pkl will be ``` (6, ('__torch__.mobile.test_lite_script_type.MyTestModule.forward', (('instructions', (('STOREN', 1, 2), ('DROPR', 1, 0), ('MOVE', 2, 0), ('LIST_CONSTRUCT', 0, 1), ('NAMED_TUPLE_CONSTRUCT', 1, 1), ('RET', 0, 0))), ('operators', ()), ('constants', ()), ('types', ('List[Tensor]', '__torch__.mobile.test_lite_script_type.myNamedTuple[NamedTuple, [[a, ' 'List[Tensor]]]]')), ('register_size', 2)), (('arguments', ((('name', 'self'), ('type', '__torch__.mobile.test_lite_script_type.MyTestModule'), ('default_value', None)), (('name', 'a'), ('type', 'Tensor'), ('default_value', None)))), ('returns', ((('name', ''), ('type', '__torch__.mobile.test_lite_script_type.myNamedTuple'), ('default_value', None)),))))) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62612 ghstack-source-id: 141485500 Test Plan: fb: 1. Add a simple unittest to test NamedTuple custom class 2. Use following cpp code (D30271153) ``` TEST(LiteTrainerTest, CustomOp) { std::string jit_model = "/home/chenlai/local/notebooks/ads_dper_fl_model_282250609.pt"; Module jit_m = load(jit_model); jit_m.eval(); torch::jit::Module module_freeze = freeze(jit_m); IValue tuple = c10::ivalue::Tuple::create({1 * torch::ones({10, 1034}), 3 * torch::ones({10, 1034})}); std::vector<IValue> inputs_1{tuple}; auto jit_output = jit_m.forward(inputs_1); jit_output.dump(); std::stringstream ss; jit_m._save_for_mobile(ss); jit_m._save_for_mobile("/home/chenlai/local/notebooks/tmp/tmp.ptl"); torch::jit::mobile::Module mobile_m = _load_for_mobile(ss); auto mobile_output = mobile_m.forward(inputs_1); std::cout << "mobile output: " << std::endl; mobile_output.dump(); } ``` And output from both mobile and jit are ``` {prediction: ([ CPUFloatType{0} ], [ CPUFloatType{0} ])} ``` 3. N1033894 with model inspection, also compare the result between jit and mobile with the dper model. Reviewed By: iseeyuan Differential Revision: D30004716 fbshipit-source-id: cfd30955e66a604af8f9633b1b608feddc13d7d7
Author
Parents
Loading