6eab5e88 - Graph-break on allowed modules if they have hooks (#97184)

Allowed modules are stuck into dynamo's fx graph as call_module nodes, without dynamo doing any tracing of the module. This means that during AOT trace time, hooks will fire when the call_module is executed, but the hooks themselves will disappear after that and not be present in the compiled program. (Worse, if they performed any tensor operations, those would get traced, so you could end up with part of the hook's functionality in the graph.)

To circumvent this, there are two options for 'allowed modules' with hooks:

1) Don't treat them as 'allowed' - trace into them.
2) Graph-break, so the module is no longer part of the dynamo trace at all.

(1) will fail for users that opted into allowed modules because they know their module has problems being traced by dynamo. (2) causes graph breaks on common modules such as nn.Linear, just because they are marked as 'allowed'.

It would help matters if we could differentiate between two types of allowed modules:

(A) allowed to avoid overheads - used for common ops like nn.Linear
(B) allowed to avoid dynamo graph breaks caused by unsupported code

Ideally, we'd use method (1) for group (A) and method (2) for group (B). For now, graph-break on all cases of allowed modules with hooks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97184
Approved by: https://github.com/jansel
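For context, a minimal sketch (not taken from the commit) of the kind of hook this change is about: a forward hook registered on an "allowed" module like nn.Linear. In eager mode the hook both runs its side effect and rewrites the output; under the pre-fix behavior described above, dynamo would emit the nn.Linear as a call_module node without tracing into it, so a hook like this could fire during AOT tracing yet be missing (or only partially captured) in the compiled program. The hook name and variables here are illustrative only.

```python
import torch
import torch.nn as nn

fired = []

def scale_hook(module, inputs, output):
    # Side effect plus a tensor op: the side effect would be lost in the
    # compiled program, while the `output * 2` could leak into the trace.
    fired.append(1)
    return output * 2

linear = nn.Linear(4, 4)
handle = linear.register_forward_hook(scale_hook)

x = torch.randn(2, 4)
out = linear(x)  # eager: hook fires once and doubles the output
handle.remove()
```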