jax
2e548e5b - AOT: remove explicit lojax stage (`Fallen`) and extend `Traced` instead

Commit
16 days ago
AOT: remove explicit lojax stage (`Fallen`) and extend `Traced` instead We can simply hang lojax information off of a `Traced` instance (as `Traced.lojax`), without making it a `Stage` of its own. From an external point of view, it seems OK for our various jaxpr levels to remain bundled under the tracing stage of AOT. This change preserves some degree of explicit public control over when the work of producing lojax is carried out, by computing it lazily. Note that we had no tests exercising `fall()` or `Fallen` anyway, suggesting that we have been treating these as internal symbols rather than a tested public API. PiperOrigin-RevId: 840433294
Author
Parents
Loading