pytorch
e44586a7 - Pass input tensor __dict__ along to placeholder nodes (#94080)

Commit
1 year ago
Pass input tensor __dict__ along to placeholder nodes (#94080) ``` import torch import torch.nn as nn import torch._dynamo.config import torch._inductor.config def pre_attention_state_ops(input, mems, state): lc_key = state[0] lc_val = state[1] bar = [] for i in range(0, 4): bar2 = [] for j in range(0, 3): bar2.append( lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) ) bar.append(bar2) return bar mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) state = [ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), ] i = torch.tensor( [ [0.0313, -0.1487, -0.3846, -0.5321], [-1.7073, 1.3331, -0.0890, -1.4935], [-0.8314, -0.1862, -0.5935, 1.5232], ] ) torch._dynamo.tag(mems, "MEMS") torch._dynamo.tag(i, "FOO") torch._dynamo.tag(state[0], "STATE_0") torch._dynamo.tag(state[1], "HMMM") exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state) out_graph = exported[0] dynamo_result = out_graph(i, mems, state) nodes = list(out_graph.graph.nodes) placeholders = [node for node in nodes if node.op == "placeholder"] for placeholder in placeholders: if "tags" in placeholder.meta: print("PLACEHOLDER TAGS?", placeholder.meta["tags"]) ``` prints PLACEHOLDER TAGS? ['STATE_0'] PLACEHOLDER TAGS? ['HMMM'] Pull Request resolved: https://github.com/pytorch/pytorch/pull/94080 Approved by: https://github.com/ezyang, https://github.com/jansel
Author
Committer
Parents
Loading