b654281f - Preserve user annotation in graph (#163673)

136 days ago
Preserve user annotation in graph (#163673)

Summary:

```python
import torch
import torch.export
import torch.fx.traceback as fx_traceback

class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x

m = M()
with fx_traceback.preserve_node_meta():
    ep = torch.export.export(m, (torch.randn(10),))

for node in ep.graph.nodes:
    if node.op == "call_function":
        print(f"{node.target}, {node.meta.get('custom', {})}")
```

prints

```
aten.add.Tensor, {'pp_stage': 0, 'fdsp_bucket': 0}
aten.sub.Tensor, {'pp_stage': 0}
aten.mul.Tensor, {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}
aten.div.Tensor, {}
```

TODOs:
- run_decomposition is failing
- Need to test with the new full graph capture + aot_export_joint APIs
- Need to make the annotation propagate through the autograd engine to reach the bw nodes. Sample impl here: https://github.com/pytorch/pytorch/pull/83558
- Edward wants to restrict the keys in the custom field to top-level singleton objects only
- Also need to take care of metadata merging when passes fuse nodes

Thanks angelayi for contributing the dynamo fixes.

X-link: https://github.com/pytorch/pytorch/pull/163673
Approved by: https://github.com/albanD, https://github.com/angelayi
Reviewed By: Camyll
Differential Revision: D83298681
fbshipit-source-id: 81c365a97ae00bbed783f736b888309c43351f56
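The output above shows the merging semantics of nested annotations: each inner `annotate` overlays its dict on the enclosing one, and leaving the context restores the previous state. A minimal plain-Python sketch of that behavior, with no torch dependency — the names `annotate` and `current_annotation` only mirror the `fx_traceback` API for illustration and are not the actual implementation:

```python
from contextlib import contextmanager

# Stack of merged annotation dicts; the top entry is the active annotation.
_stack = [{}]

@contextmanager
def annotate(meta):
    # Inner keys override outer ones on collision, matching the merged
    # dicts seen on the exported nodes above.
    _stack.append({**_stack[-1], **meta})
    try:
        yield
    finally:
        _stack.pop()  # restore the enclosing annotation on exit

def current_annotation():
    return dict(_stack[-1])

with annotate({"pp_stage": 0}):
    with annotate({"fdsp_bucket": 0}):
        print(current_annotation())  # {'pp_stage': 0, 'fdsp_bucket': 0}
    print(current_annotation())      # {'pp_stage': 0}
print(current_annotation())          # {}
```

Under this model, a node created inside a context picks up the top of the stack as its `custom` metadata, which is why `aten.div.Tensor` (created outside every context) ends up with `{}`.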