Add tree-mode jit (graph=False) to nnx.jit
`nnx.jit` now accepts a `graph` parameter (default `True`). When `graph=False`,
Modules are treated as regular JAX pytrees instead of going through the NNX
graph protocol. This simpler mode behaves more like JAX as it assumes
referential transparency (no sharing), it only propagates updates for
Variables, roughly matching the expected behavior of mutable Hijax.
Tree-mode removes the need for `graph.update_context` (the most complex
part of NNX) and the NNX prefix/Lift APIs such as `StateAxes`,
`StateSharding`, and `DiffState`. The implementation is thus much simpler,
easier to maintain and optimize.
Tree-mode enforces structural constraints via `check_no_aliases` and
`apply_variable_updates`: input Variables cannot be returned as outputs,
and shared Variable references within inputs are rejected.
Limitations of tree-mode (graph=False):
- Shared Variable references are not supported (e.g. two sub-modules
pointing to the same Linear layer will raise an error).
- Input Variables cannot be returned as outputs from the jitted function.
- Capturing updates from the backward pass and forwarding state updates
to captured objects is not available.
The existing graph-mode (graph=True) remains the default and fully
backward compatible. Users are not forced to migrate.
PiperOrigin-RevId: 868946704