pytorch
2b12bfce - [dynamo] Skip frame when graph break in a loop (#88857)

Commit
2 years ago
[dynamo] Skip frame when graph break in a loop (#88857) This fixes excessing recompilation issue in tacotron2 but has few caveats - https://github.com/pytorch/torchdynamo/issues/330 For tacotron2, the repro is something like this ~~~ def inner(x): return torch.sin(x) def fn(x): for _ in range(100): inner(x) torch._dynamo.graph_break() return x ~~~ The problem here is that Dynamo has guards on the TUPLE_ITERATOR_LEN whenever a graph break happens. Therefore, we keep on recompiling. This PR checks if there is a backedge (helps with while loop) in presence of a graph break. If there is, Dynamo skips processing this frame. Therefore, Dynamo gets called when inner is called, and we compile only once. Note that, if there was no graph break, we will unroll the original loop, and see one graph with 100 sin operations (just as before, so no changes there). The caveat is - We are skipping the frame, so if we have something like this ~~~ def fn(x): for _ in range(100): # 1000s of lines of PyTorch code torch._dynamo.graph_break() return x ~~~ Dynamo will skip processing this frame, and might miss on the optimization. Completely open for suggestions. Happy to re-implement if there is a better way to handle this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88857 Approved by: https://github.com/jansel, https://github.com/yanboliang
Author
Committer
Parents
Loading