pytorch
a8a1c576 - Reset joint graph fake mode earlier, and more comprehensively (#99391)

Commit
1 year ago
Reset joint graph fake mode earlier, and more comprehensively (#99391) This bug was discovered by a stronger assert (which I will be posting in a follow up PR.) The explanation for this change is a bit long and windy, and I am not sure I entirely understand the situation myself. But here's what I think is going on. jansel's joint graph pattern matcher does something fairly unusual: in order to initialize the pattern in question, it (lazily) runs an aot_function invocation in order to trace out what the joint graph of a given pattern looks like (we ought not use aot_function, but we can't really do this until bdhirsh lands AOT Autograd export properly). However, this lazy initialization occurs within the context of a separate compilation, which has its own tracing context, and importantly, fake tensor mode. What we would like, is the pattern matcher lazy initialization fake tensor mode to be unrelated to whatever the ambient fake tensor mode of the graph we were actually compiling. We want these to be independent, because we don't really care what the current compiled graph is; this is a lazy init function, it could have gotten initialized during any compilation, it just happens to be initialized on this one. To prevent us from picking up the ambient fake mode, we have to do two things: we have to remove the tracing context (which stores a fake mode), and we have to also disable the ambiently active fake mode. In https://github.com/pytorch/pytorch/pull/99377 eellison proposed an alternative approach, where we reuse the fake mode. While this probably won't cause any errors, it's morally not the right thing to do, because you'll end up polluting the enclosing fake tensor mode with tensors that have nothing to do with the mode itself. This might fix https://github.com/pytorch/pytorch/issues/99286 but it's also possible that https://github.com/pytorch/pytorch/pull/99320 fixed it already. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/99391 Approved by: https://github.com/bdhirsh
Author
Committer
Parents
Loading