pytorch
b5a925ff - propagate .meta info when replacing subgraphs in fx (#87255)

Commit
2 years ago
propagate .meta info when replacing subgraphs in fx (#87255) Fixes https://github.com/pytorch/torchdynamo/issues/1708 Our FX subgraph partitioner works by taking all of the original output nodes from a subgraph, and replacing it with a new `call_module` node in the graph. If the original subgraph outputs had fake tensors and other metadata stored in their `.meta` attribute though, then this information was getting lost when we spliced in the subgraph. Losing metadata on an FX graph also seems like an easy trap to fall into, so I'm wondering if there are any better guardrails that we can add. I ended up fixing in this PR by adding an optional kwarg to propagate meta info directly in the `fx.Node.replace_all_uses_with`, just because propagating metadata seems like a pretty core thing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87255 Approved by: https://github.com/wconstab, https://github.com/SherlockNoMad
Author
Committer
Parents
Loading