pytorch
a12e92d8 - Support nn.Module forward hooks in torchdynamo (#92125)

Commit
1 year ago
Support nn.Module forward hooks in torchdynamo (#92125) Tweak dynamo behavior in 2 places when calling nn.Modules, to route the call to __call__ instead of .forward(), since __call__ is the codepath that eager users hit and will dispatch to hooks correctly. (1) inside NNModuleVariable.call_function, which covers the common case of calling a module from code dynamo is already tracing (2) at the OptimizedModule layer, which is the entrypoint into a top-level nn.Module dynamo is about to compile This exposes a new bug: NNModuleVariable used to special-case calling module.forward() (which is a method) as a UserFunctionVariable with an extra 'self' arg. After tracing into module.__call__, there is no longer a special case for the eventual call into .forward, and it gets wrapped in a UserDefinedObjectVariable following standard behavior of ._wrap(). UDOV can't be called, so this broke some tests. - Fix: add a new special case in _wrap() that treats methods as a UserDefinedMethod instead of UserDefinedObjectVariable. Now, the forward method can be called. Also, fix NNModuleVar.call_method routing forward back to __call__ Pull Request resolved: https://github.com/pytorch/pytorch/pull/92125 Approved by: https://github.com/ezyang, https://github.com/jansel, https://github.com/voznesenskym
Author
Committer
Parents
Loading