flax
b528725b - Add tree-mode jit (graph=False) to nnx.jit

Commit
4 days ago
Add tree-mode jit (graph=False) to nnx.jit `nnx.jit` now accepts a `graph` parameter (default `True`). When `graph=False`, Modules are treated as regular JAX pytrees instead of going through the NNX graph protocol. This simpler mode behaves more like JAX as it assumes referential transparency (no sharing), it only propagates updates for Variables, roughly matching the expected behavior of mutable Hijax. Tree-mode removes the need for `graph.update_context` (the most complex part of NNX) and the NNX prefix/Lift APIs such as `StateAxes`, `StateSharding`, and `DiffState`. The implementation is thus much simpler, easier to maintain and optimize. Tree-mode enforces structural constraints via `check_no_aliases` and `apply_variable_updates`: input Variables cannot be returned as outputs, and shared Variable references within inputs are rejected. Limitations of tree-mode (graph=False): - Shared Variable references are not supported (e.g. two sub-modules pointing to the same Linear layer will raise an error). - Input Variables cannot be returned as outputs from the jitted function. - Capturing updates from the backward pass and forwarding state updates to captured objects is not available. The existing graph-mode (graph=True) remains the default and fully backward compatible. Users are not forced to migrate. PiperOrigin-RevId: 868946704
Author
Cristian Garcia
Committer
Parents
Loading