pytorch
bac26155 - [JIT] Allow freezing modules that contain mutable interfaces (#86039)

Commit
2 years ago
[JIT] Allow freezing modules that contain mutable interfaces (#86039) This PR allows freezing modules like the one below: ```python # Ex. 1 @torch.jit.interface class ModuleInterface(torch.nn.Module): def forward(self, inp: torch.Tensor) -> torch.Tensor: pass class ImplementsInterface(torch.nn.Module): def __init__(self): super(ImplementsInterface, self).__init__() self.sum = torch.zeros((2, 2)) def forward(self, inp: torch.Tensor) -> torch.Tensor: self.sum += inp.relu() # this makes the interface-implementing module mutable # and previously this would prevent freezing return self.sum class WrapperModule(torch.nn.Module): impl: ModuleInterface def __init__(self): super().__init__() self.impl = ImplementsInterface() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.impl.forward(x) ``` Previously during freezing, we handle interfaces as shown below: 1. we inline interfaces in any preserved method graphs 2. during `cleanupFrozenModule`, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that _swaps out the value of an interface-typed attribute at runtime_. Below is an example of a module that swaps out the value of an interface-typed attribute at runtime: ```python # Ex. 2 class MyBadModule(torch.nn.Module): impl: MyInterface option1: IfaceImpl1 option2: IfaceImpl2 .... def forward(self, x): if x > 0: self.impl = self.option1 else: self.impl = self.option2 .... ``` ^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure. But we can still support the first example without _too_ much work: 1. inline the interface code as before 2. check to see if we have any setattrs on interface types; if so, error out 3. otherwise, replace the type of the interface types with the concrete type implementation 4. continue simplifying the module data structure as if we never had any interfaces. Pull Request resolved: https://github.com/pytorch/pytorch/pull/86039 Approved by: https://github.com/eellison
Author
Committer
Parents
Loading