pytorch
c8334859 - [dynamo][api] Better support of torch.nn.Module (#88629)

Commit
2 years ago
[dynamo][api] Better support of torch.nn.Module (#88629) This is an API change, so please review carefully. With this PR, torchdynamo returns an `OptimizedModule` class object, a subclass of `torch.nn.Module`, when asked to optimize a `nn.Module` object. Most of the methods are redirected to the original `nn.Module`, which is installed as `_mod` in the `OptimizedModule`. This is helpful for many cases ``` mod = MockModule() opt_mod = torch._dynamo.optimize()(mod) print(opt_mod) # Works opt_mod = opt_mod.to(device="cuda") print(opt_mod) # Works opt_mod(input) # Triggers recompile if necessary, earlier we were shedding the TorchDynamo wrapper opt_mod.parameters() # Refers to the original module ``` Topics unclear to me * I have overridden many methods to raise NotImplementedError. A careful review of those will be good. * hooks * For the optimized forward, should we call torchdynamo optimization on `__call__` or `forward` * What else to test Pull Request resolved: https://github.com/pytorch/pytorch/pull/88629 Approved by: https://github.com/Chillee, https://github.com/jansel, https://github.com/msaroufim
Author
Committer
Parents
Loading