[JIT] Allow freezing modules that contain mutable interfaces (#86039)
This PR allows freezing modules like the one below:
```python
# Ex. 1
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
def forward(self, inp: torch.Tensor) -> torch.Tensor:
pass
class ImplementsInterface(torch.nn.Module):
def __init__(self):
super(ImplementsInterface, self).__init__()
self.sum = torch.zeros((2, 2))
def forward(self, inp: torch.Tensor) -> torch.Tensor:
self.sum += inp.relu() # this makes the interface-implementing module mutable
# and previously this would prevent freezing
return self.sum
class WrapperModule(torch.nn.Module):
impl: ModuleInterface
def __init__(self):
super().__init__()
self.impl = ImplementsInterface()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.impl.forward(x)
```
Previously during freezing, we handle interfaces as shown below:
1. we inline interfaces in any preserved method graphs
2. during `cleanupFrozenModule`, we try to simplify the module data structure (<- this part is unrelated to freezing so far). During this step, if we found that a interface type was mutable, we'd error out; because of the possibility of a module that _swaps out the value of an interface-typed attribute at runtime_.
Below is an example of a module that swaps out the value of an interface-typed attribute at runtime:
```python
# Ex. 2
class MyBadModule(torch.nn.Module):
impl: MyInterface
option1: IfaceImpl1
option2: IfaceImpl2
....
def forward(self, x):
if x > 0:
self.impl = self.option1
else:
self.impl = self.option2
....
```
^ this type of situation cannot be supported by freezing (or at least would be difficult to do correctly) because it greatly complicates the details of handling types and simplifying the module data structure.
But we can still support the first example without _too_ much work:
1. inline the interface code as before
2. check to see if we have any setattrs on interface types; if so, error out
3. otherwise, replace the type of the interface types with the concrete type implementation
4. continue simplifying the module data structure as if we never had any interfaces.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86039
Approved by: https://github.com/eellison