pytorch
a896962f - [fx][2/n] Add metadata to placeholders (#102195)

Commit
1 year ago
[fx][2/n] Add metadata to placeholders (#102195) Summary: # Context In TorchRec's train pipeline, we need to fx trace a module to analyze the arguments on the forward call. In order to do this, we need to preserve some sort of meaning with each argument (a key or name of sorts that lets us identify the argument). The issue is, when you use concrete args, internally, fx will unflatten the arg into it's constituents (to locate PHs). Given a function that looks like this: ``` def process(batch: Dict[str, torch.Tensor]): .... symbolic_trace(process, concrete_args: {"batch": {"f1": PH, "f2": PH}}) # function will be rewritten to look like: def process(batch_1, batch_2): # batch_1 -> "f1", batch_2->"f2" ... ``` When you traverse through the nodes of the graph, the names of the argument nodes to the function are batch_1 and batch_2. **This doesn't mean anything to the user who is fx tracing.** There isn't anything indicating that batch_1 corresponds to key "f1" in the batch input. # Solution When fx sees a "PH", it creates a proxy node. The user does not have direct access to proxy creation, but only through the PH structure. Attach a piece of metadata, `ph_key`, to the PH when you set it in the concrete args, it will get passed into proxy + node creation. So when you traverse the graph, this metadata sticks onto the node as an attribute. This way you have a way of tagging that "batch_1" as "f1". Test Plan: added a unit test Reviewed By: dstaay-fb Differential Revision: D44947653 Pull Request resolved: https://github.com/pytorch/pytorch/pull/102195 Approved by: https://github.com/PaliC
Author
Committer
Parents
Loading