pytorch
eeba9d5a - Preserve node's meta during fx.transformation (#90737)

Commit
1 year ago
Preserve node's meta during fx.transformation (#90737) We wish to preserve node.meta over fx.Transformer transformation and aot_autograd. This will preserve all the meta fields in the original node, including stack_trace, nn_module_stack, val, tensor_meta... Sample Here's a graph produced by Dynamo. ``` class GraphModule(torch.nn.Module): def forward(self, x : torch.Tensor, y : torch.Tensor): # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x) cos = torch.cos(x); x = None # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y) sin = torch.sin(y); y = None # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b add = cos + sin; cos = sin = None return (add,) x {'creation_timestamp': 0, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n def forward(self, x, y):\n'} y {'creation_timestamp': 0, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 45, in forward\n def forward(self, x, y):\n'} cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n a = torch.cos(x)\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n'} sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n b = torch.sin(y)\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n'} add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n return a + b\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n'} output {'creation_timestamp': 4} ``` After lowering to aten graph with aot_autograd_simplified() ``` class GraphModule(torch.nn.Module): def forward(self, primals_1: f32[2, 3], primals_2: f32[2, 3]): # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:35, code: a = torch.cos(x) cos: f32[2, 3] = torch.ops.aten.cos.default(primals_1) # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:36, code: b = torch.sin(y) sin: f32[2, 3] = torch.ops.aten.sin.default(primals_2) # File: /scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py:37, code: return a + b add: f32[2, 3] = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None return [add, primals_2, primals_1] primals_1 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} primals_2 {'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=True, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} cos {'creation_timestamp': 3, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 35, in forward\n a = torch.cos(x)\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} sin {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 36, in forward\n b = torch.sin(y)\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} add {'creation_timestamp': 4, 'nn_module_stack': {'self_block': "<class '__main__.Block'>"}, 'stack_trace': ' File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 37, in forward\n return a + b\n | File "/scratch/bahuang/work/repos/pytorch/temp/dynamo_aotautograd_demo.py", line 46, in forward\n return self.block(x, y)\n', 'val': FakeTensor(FakeTensor(..., device='meta', size=(2, 3)), cpu), 'tensor_meta': TensorMetadata(shape=torch.Size([2, 3]), dtype=torch.float32, requires_grad=False, stride=(3, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})} output {} ``` Notice that output fx node have creation_time_stamp, nn_module_stack and stack_trace copied from the original fx node. val and tensor_meta were latter populated by a subsequent fake_tensor_propagation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90737 Approved by: https://github.com/jerryzh168
Author
Committer
Parents
Loading