Pass input tensor __dict__ along to placeholder nodes (#94080)
```
import torch
import torch.nn as nn
import torch._dynamo.config
import torch._inductor.config
def pre_attention_state_ops(input, mems, state):
lc_key = state[0]
lc_val = state[1]
bar = []
for i in range(0, 4):
bar2 = []
for j in range(0, 3):
bar2.append(
lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
)
bar.append(bar2)
return bar
mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
state = [
torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
]
i = torch.tensor(
[
[0.0313, -0.1487, -0.3846, -0.5321],
[-1.7073, 1.3331, -0.0890, -1.4935],
[-0.8314, -0.1862, -0.5935, 1.5232],
]
)
torch._dynamo.tag(mems, "MEMS")
torch._dynamo.tag(i, "FOO")
torch._dynamo.tag(state[0], "STATE_0")
torch._dynamo.tag(state[1], "HMMM")
exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state)
out_graph = exported[0]
dynamo_result = out_graph(i, mems, state)
nodes = list(out_graph.graph.nodes)
placeholders = [node for node in nodes if node.op == "placeholder"]
for placeholder in placeholders:
if "tags" in placeholder.meta:
print("PLACEHOLDER TAGS?", placeholder.meta["tags"])
```
prints
PLACEHOLDER TAGS? ['STATE_0']
PLACEHOLDER TAGS? ['HMMM']
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94080
Approved by: https://github.com/ezyang, https://github.com/jansel