# pytorch c9adc4c3 - [Dynamo] De-dup graph inputs (#98775)

[Dynamo] De-dup graph inputs (#98775)

### Overview
This PR de-duplicates graph inputs in TorchDynamo, using the `Source` as the unique identifier for each input. This closes https://github.com/pytorch/pytorch/issues/98743 and https://github.com/pytorch/pytorch/issues/98625.

### Details
`VariableBuilder.wrap_tensor()` should return a `VariableTracker` for the passed-in `value: Tensor`. If `value` is duplicated, we should avoid calling `OutputGraph.create_graph_input()` and `OutputGraph.add_grapharg()` (see the illustrative sketch at the end of this description).
- Note that `create_graph_input()` and `add_grapharg()` are not 1:1. For a constant source and either `wrap_sym()` or `wrap_unspecialized_primitive()`, TorchDynamo still calls `create_graph_input()` but not `add_grapharg()`.
- Note that `create_graph_input()` should be called before constructing the corresponding `VariableTracker`. TorchDynamo needs the `fx.Proxy` object to pass to `wrap_fx_proxy()`.

In this PR, the `OutputGraph` saves an additional mapping `input_source_to_var` from each graph input's `Source` to its `VariableTracker`, which works because `Source` is now hashable. This mapping should be updated each time `create_graph_input()` is called. However, since we must construct the `VariableTracker` after `create_graph_input()` returns, we need a separate call into the `OutputGraph` to update the mapping. If anyone has a suggestion on how to coalesce this logic and avoid having to remember to update `input_source_to_var` for each `create_graph_input()`, I would love to hear it.

<details>
<summary>Alternate Approach</summary>

Initially, I tried having TorchDynamo construct a new but equivalent `VariableTracker` for the duplicated tensor. However, I abandoned this approach after hitting an assertion in `wrap_fx_proxy_cls()`: `"example_value"` was already in the proxy node's metadata because we were reusing the primary tensor's `Proxy` object. Reusing the exact `VariableTracker` also seems less error-prone than constructing a new but identical one.

</details>

### Testing

#### Global Variable Test
```python
import torch

@torch.compile()
def f():
    return x + x

x = torch.randn(3)
f()
```
Before:
```
====== Forward graph 0 ======
<eval_with_key>.6 class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3], arg1_1: f32[3]):
        # File: /data/users/ezyang/b/pytorch/ff.py:5, code: return x + x
        add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
        return (add,)
```
After (only `arg0_1` and no more `arg1_1`):
```
====== Forward graph 0 ======
<eval_with_key>.4 class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: f32[3]):
        # File: dynamo/test_dup_global.py:8, code: return x + x
        add: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
        return (add,)
```

#### FSDP Test
Before, we errored on
```
File "/.../pytorch/torch/_guards.py", line 244, in __post_init__
    assert self.input_source_a != self.input_source_b
```
and now there is no error.

---

The rename from `name_to_input` to `input_name_to_proxy` is not part of the core logic change and is a remnant of my initial attempts. I can undo it later if desired, but I also feel the new name is more informative. It also fixes the type annotation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98775
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
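
---

The sketch below is a minimal, self-contained illustration of the lookup-or-create pattern described in the Details section, not the actual TorchDynamo code. The class and function shapes (`Source`, `VariableTracker`, `OutputGraph`, `wrap_tensor`, `register_input_var`) are assumptions made only for this example; the real types and signatures in `torch/_dynamo` differ. It shows the two key points: the `Source`-keyed `input_source_to_var` map short-circuits duplicate inputs, and the map is updated in a separate call because the `VariableTracker` can only be built after `create_graph_input()` hands back a proxy.

```python
# Hypothetical sketch only; mirrors the control flow, not the real TorchDynamo API.
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen -> hashable, analogous to Source now being hashable
class Source:
    name: str


class VariableTracker:
    def __init__(self, proxy):
        self.proxy = proxy


class OutputGraph:
    def __init__(self):
        self.input_source_to_var = {}   # Source -> VariableTracker
        self.input_name_to_proxy = {}   # graph input name -> proxy placeholder

    def create_graph_input(self, name):
        proxy = f"placeholder({name})"  # stand-in for an fx.Proxy
        self.input_name_to_proxy[name] = proxy
        return proxy

    def register_input_var(self, source, var):
        # Separate call: the VariableTracker is only constructed *after*
        # create_graph_input() returns, because it needs the proxy.
        self.input_source_to_var[source] = var


def wrap_tensor(graph, source, name):
    # De-dup: if this Source was already wrapped, reuse the exact tracker
    # instead of creating a second graph input for the same value.
    if source in graph.input_source_to_var:
        return graph.input_source_to_var[source]
    proxy = graph.create_graph_input(name)
    var = VariableTracker(proxy)
    graph.register_input_var(source, var)
    return var


graph = OutputGraph()
src = Source("x")
a = wrap_tensor(graph, src, "arg0_1")
b = wrap_tensor(graph, src, "arg0_1")
assert a is b                               # duplicated input reuses the tracker
assert len(graph.input_name_to_proxy) == 1  # only one graph input was created
```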