[fx][2/n] Add metadata to placeholders (#102195)
Summary:
# Context
In TorchRec's train pipeline, we need to fx trace a module to analyze the arguments of its forward call. To do this, we need to preserve some meaning for each argument (a key or name that lets us identify it).
The issue is that when you use concrete args, fx will internally flatten the arg into its constituents (to locate PHs).
Given a function that looks like this:
```python
def process(batch: Dict[str, torch.Tensor]):
    ...

symbolic_trace(process, concrete_args={"batch": {"f1": PH, "f2": PH}})

# the function will be rewritten to look like:
def process(batch_1, batch_2):  # batch_1 -> "f1", batch_2 -> "f2"
    ...
```
When you traverse the nodes of the graph, the names of the argument nodes to the function are batch_1 and batch_2. **These names don't mean anything to the user who is fx tracing.** There is nothing indicating that batch_1 corresponds to key "f1" in the batch input.
# Solution
When fx sees a "PH", it creates a proxy node.
The user does not have direct access to proxy creation; the only hook is the PH object itself.
So we attach a piece of metadata, `ph_key`, to the PH when setting it in the concrete args. It gets passed through proxy and node creation, so when you traverse the graph, the metadata sticks to the node as an attribute. This gives you a way of tagging that "batch_1" as "f1".
Test Plan: added a unit test
Reviewed By: dstaay-fb
Differential Revision: D44947653
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102195
Approved by: https://github.com/PaliC