pytorch
4bfa51d4 - [jit] fix trace checking reporting divergent names (#37464)

5 years ago
[jit] fix trace checking reporting divergent names (#37464)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37464

Fixes https://github.com/pytorch/pytorch/issues/23993.

There are two fixes here:

1. Previously, our name lookup function for the tracer looked in f.globals for names. For example:

   ```
   sample = torch.ones(1)
   traced = torch.jit.trace(my_mod, ((sample, sample,),))
   # produces a graph with something like
   #   %sample, %sample = prim::TupleUnpack(%input)
   ```

   This is not great if you are, e.g., trace checking, because a non-local bit of interpreter state affects the graph produced:

   ```
   traced = torch.jit.trace(my_mod, _clone_inputs((sample, sample,),))
   # produces a graph with something like
   #   %0, %1 = prim::TupleUnpack(%input)
   ```

   I have removed this functionality, as I don't think it provides huge value. Things that look locally for names will still work, so inputs, intermediate variables, and the like will be named correctly.

2. Previously, our input cloning for trace checking didn't do a memoized deep copy. So:

   ```
   _clone_inputs((sample, sample, sample))
   ```

   produced a tuple of three non-aliased tensors. That's wrong! Use copy.deepcopy with a memoization argument to fix this.

Test Plan: Imported from OSS

Differential Revision: D21297549

Pulled By: suo

fbshipit-source-id: 981d5879a4a244520dd68489767129ff357f1497
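A minimal sketch of the memoized-copy idea behind fix 2. This is not the actual `_clone_inputs` implementation; `clone_inputs` is a hypothetical stand-in, and plain Python lists stand in for tensors. The point is that passing one shared `memo` dict through repeated `copy.deepcopy` calls makes Python reuse the same clone for the same original object, so aliasing among the inputs is preserved in the copy:

```python
import copy

def clone_inputs(inputs):
    # Hypothetical stand-in for _clone_inputs: deep-copy each input,
    # but share one memo dict (id(original) -> clone) across all calls
    # so repeated/aliased inputs map to a single shared clone.
    memo = {}
    return tuple(copy.deepcopy(x, memo) for x in inputs)

sample = [1.0]  # stand-in for a tensor; any mutable object works
a, b, c = clone_inputs((sample, sample, sample))

# The three clones are one object, mirroring the original aliasing,
# while still being distinct from the original input.
print(a is b and b is c)  # True
print(a is sample)        # False

# Without a shared memo, each deepcopy call yields an independent
# clone -- the pre-fix behavior the commit calls out as wrong.
x, y = copy.deepcopy(sample), copy.deepcopy(sample)
print(x is y)             # False
```

Mutating `a` after cloning is then visible through `b` and `c`, exactly as a trace-checking run expects when the user passed the same tensor three times.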
Author
suo