Reset joint graph fake mode earlier, and more comprehensively (#99391)
This bug was discovered by a stronger assert (which I will be posting
in a follow-up PR).
The explanation for this change is a bit long-winded, and I am not
sure I entirely understand the situation myself. But here's what I
think is going on.
jansel's joint graph pattern matcher does something fairly unusual:
in order to initialize the pattern in question, it (lazily) runs
an aot_function invocation in order to trace out what the joint graph
of a given pattern looks like (we ought not use aot_function, but we
can't really do this until bdhirsh lands AOT Autograd export properly).
However, this lazy initialization occurs within the context of a
separate compilation, which has its own tracing context, and
importantly, fake tensor mode.
What we would like is for the fake tensor mode used by the pattern
matcher's lazy initialization to be unrelated to the ambient fake tensor
mode of the graph we are actually compiling. We want these to be
independent, because we don't really care what the currently compiled
graph is; this is a lazy init function, it could have been initialized
during any compilation, it just happens to be initialized in this one.
To prevent us from picking up the ambient fake mode, we have to do two
things: we have to remove the tracing context (which stores a fake
mode), and we have to also disable the ambiently active fake mode.
In https://github.com/pytorch/pytorch/pull/99377 eellison proposed an
alternative approach, where we reuse the fake mode. While this probably
won't cause any errors, it's morally not the right thing to do, because
you'll end up polluting the enclosing fake tensor mode with tensors that
have nothing to do with the mode itself.
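
To make the pollution concern concrete, here is a small toy model (not
PyTorch code; ToyFakeMode and init_pattern are hypothetical names, and
the assumption that a mode keeps a record of the tensors it creates is
only there to make the effect visible). It compares reusing the
enclosing mode against giving the pattern its own mode.

```python
class ToyFakeMode:
    """Hypothetical stand-in for a fake tensor mode that records what it creates."""

    def __init__(self, name):
        self.name = name
        self.owned = []

    def make(self, label):
        self.owned.append(label)
        return label


def init_pattern(ambient, reuse):
    # reuse=True models reusing the enclosing mode;
    # reuse=False models giving the pattern an independent mode.
    mode = ambient if reuse else ToyFakeMode("pattern")
    mode.make("pattern_arg0")
    mode.make("pattern_arg1")
    return mode


reused = ToyFakeMode("compile")
reused.make("graph_input")
init_pattern(reused, reuse=True)    # pattern tensors land in the compile mode

fresh = ToyFakeMode("compile")
fresh.make("graph_input")
init_pattern(fresh, reuse=False)    # compile mode only holds its own tensor
```

In the reuse case the enclosing mode ends up owning the pattern's
tensors alongside the graph's own input; with an independent mode it
does not.
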
This might fix https://github.com/pytorch/pytorch/issues/99286
but it's also possible that https://github.com/pytorch/pytorch/pull/99320
fixed it already.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99391
Approved by: https://github.com/bdhirsh