DeepSpeed
077bff55 - deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits (#7980)

Commit
18 days ago
deepcompile: Fix backward graph recompilation due to unbalanced forward/backward visits (#7980) In PyTorch AOT Autograd, having tensors requiring grad in inputs doesn't guarantee backward graph compilation. If no output requires grad and no input requiring grad is mutated, aot_autograd skips backward compilation (see [1]). DeepCompile previously required backward compilation for every forward graph which required grad, but relied solely on the existence of require_grad tensors. This mismatch caused unbalanced forward/backward visits, leaving graphs unvisited in `frames_needing_bwd`. The patched FunctionMeta then remained effective during backward execution, raising KeyError when removing the (already-removed) frame IDs from the `frames_needing_bwd` set. A reproduction can be found at [2]. Simply put a guard on the set removal operation is insufficient. The backward graph is still recompiled on each iteration, severely impacting performance. Instead of duplicating how AOT Autograd determines whether to compile the backward graph, use the fact that a joint graph requires a backward pass if and only if it is partitioned into a forward and a backward module. The frame IDs of partitioned graphs are collected in the patched partition functions and then used to determine `needs_backward` in the forward compile function. `backend_fn` is not a proper place for the second step since autograd creates fw/bw compile functions before partitioning a joint graph. References [1] https://github.com/pytorch/pytorch/blob/aea31e0c306e2315bf6d84255e0dde7adf09762a/torch/_functorch/aot_autograd.py#L618 [2] https://gist.github.com/eternalNight/96d6bc60e2bf566fda1300154d0e89dc Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
Author
Parents
Loading