jax
16ffcf39 - [debug_info] Improve debug info for loop constructs.

Commit
209 days ago
[debug_info] Improve debug info for loop constructs. Several loop constructs (`fori_loop`, `while`) are implemented in terms of scan and we transform internally the user-provided loop body to match the API of scan. We want to make sure that the debug info matches the user-provided function, not the manufactured scan body. To achieve this we construct the debug info for the user provided function and we attach it to the attribute `__fun_debug_info__` of the manufactured scan body. Also, refactor for_loop.scan to trace its argument only once. Previously, the `f` argument for `for_loop.scap` was traced twice, once to process the initial values and then again in the scan loop. Not only is this wastefull, but in the second tracing the `arg_names` part of the debug info is messed up. See, e.g., the change in the escaped tracer debug info in the `debug_info_test::test_grad_scan`.
References
Author
Committer
Parents
Loading