Preserve user annotation in graph (#163673)
Summary:
```
import torch
import torch.fx.traceback as fx_traceback
import torch.export
class M(torch.nn.Module):
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
with fx_traceback.annotate({"fdsp_bucket": 0}):
x = x + 1
x = x - 2
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
x = x * 2
x = x / 3
return x
m = M()
with fx_traceback.preserve_node_meta():
ep = torch.export.export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.op == "call_function":
print(f"{node.target}, {node.meta.get("custom", {})}")
```
prints
```
aten.add.Tensor, {'pp_stage': 0, 'fdsp_bucket': 0}
aten.sub.Tensor, {'pp_stage': 0}
aten.mul.Tensor, {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}
aten.div.Tensor, {}
```
TODOs:
- run_decomposition is failing
- Need to test with the new full graph capture + aot_export_joint apis
- Need to make the annotation propagate through autograd engine to reach the bw nodes. Sample impl here: https://github.com/pytorch/pytorch/pull/83558
- Edward want to restrict the key in custom field to be top-level singleton objects only
- also need to take care of metadata merging when passes are fusing nodes
Thanks angelayi for contributing the dynamo fixes.
X-link: https://github.com/pytorch/pytorch/pull/163673
Approved by: https://github.com/albanD, https://github.com/angelayi
Reviewed By: Camyll
Differential Revision: D83298681
fbshipit-source-id: 81c365a97ae00bbed783f736b888309c43351f56