d1d5d16d - dynamo: handle straight-line graph breaks for autocast context manager with constant args (#94137)

Fixes https://github.com/pytorch/pytorch/issues/93890

We do the following:
1. Fix the `__init__` constructor for `AutocastModeVariable` so that an existing `mode` is preserved when copying.
2. Make `resume_execution` aware of the constant args (`target_values`) by storing them in `ReenterWith`. To propagate between subgraphs (in straight-line code), we also store the constant args in the downstream code's `code_options["co_consts"]` if they are not already present.

---

Future work:
1. Handle instantiating the context manager in non-inlineable functions, and simultaneously fix the nested grad-mode bug.
2. Generalize to arbitrary `ContextManager`s.
3. Generalize to variable arguments passed to the context manager, with guards around the variable.

---

Actually, if we look at the repro (https://github.com/pytorch/pytorch/blob/74592a43d0d33a6c809fdcfc20249e1c93e7216e/test/dynamo/test_repros.py#L1249), we can see that the method in this PR does not work for graph breaks inside function calls, in particular function calls that don't get inlined.

Why inlining functions with graph breaks is hard:
- When we handle a graph break, we create a new code object for the remainder of the code. It is hard to do this when we are inside a function: we would then need a frame stack, and we just want to deal with the current frame as a sequence of straight-line code objects.

Why propagating context manager information is hard:
- If we do not inline the function, the frame contains no information about the parent `block_stack` or `co_consts`, so we cannot store it on local objects like the eval frame. It has to be a global object in the output_graph.

---

Anyway, I'm starting to see clearly that dynamo must indeed be optimized for the torch use-case; supporting more general cases tends to run into endless corner cases and caveats. One direction that looks viable for handling function calls that have graph breaks and `has_tensor_in_frame` is to stick with not inlining them, while installing a global `ContextManagerManager`, similar to the `CleanupManager` (which cleans up global variables). We can know which context managers are active at any given point, so we can install their setup/teardown code on those functions and their fragments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94137
Approved by: https://github.com/yanboliang
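
For context, a minimal sketch of the pattern this commit targets: an autocast context manager constructed with constant args and a graph break inside it in straight-line code. This is illustrative only, not the repro from issue #93890; the function name `fn` and the explicit `torch._dynamo.graph_break()` call are chosen here just to force the situation.

```python
import torch
import torch._dynamo

# Hypothetical minimal example; the real repro lives in test/dynamo/test_repros.py.
@torch._dynamo.optimize("eager")
def fn(x):
    # autocast built from constant args ("cpu", torch.bfloat16)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y = x.sin()
        torch._dynamo.graph_break()  # splits the frame inside the context manager
        return y.cos()

print(fn(torch.randn(4)))
```

Per the commit description, the resume code generated after the break previously had no way to recover the autocast constant args, which is what step 2 above addresses.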
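
And a simplified, hypothetical illustration of step 2: for the resume code object to re-enter the context manager, its constant args must be loadable via `LOAD_CONST`, so they are appended to the downstream `code_options["co_consts"]` when missing. The helper name and structure here are illustrative, not dynamo's actual implementation.

```python
# Hypothetical helper sketching the co_consts propagation idea (not real torch._dynamo code).
def ensure_const(code_options, value):
    """Return the LOAD_CONST index for `value`, appending it to co_consts if absent."""
    consts = code_options["co_consts"]
    if value not in consts:
        code_options["co_consts"] = consts + (value,)
        consts = code_options["co_consts"]
    return consts.index(value)

# The resume code may start with co_consts that lack the autocast target_values...
code_options = {"co_consts": (None,)}
target_values = ("cpu", "torch.bfloat16")  # stand-ins for the constant autocast args
indices = [ensure_const(code_options, v) for v in target_values]
print(code_options["co_consts"], indices)  # (None, 'cpu', 'torch.bfloat16') [1, 2]
```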