flax
3ae65afe - Add tree-mode support for nnx.recursive_map, nnx.view, and nnx.view_info

Commit
12 days ago
Add tree-mode support for nnx.recursive_map, nnx.view, and nnx.view_info Add a `graph` parameter to `nnx.recursive_map`, `nnx.view`, and `nnx.view_info`. When `graph=False`, these functions use JAX's native pytree traversal instead of Flax's graph protocol. Cycles and shared Variable/Ref references are detected and raise errors in tree mode. Added parametrized tests for both graph and tree modes. PiperOrigin-RevId: 873225517
Author
Cristian Garcia
Committer
Parents
Loading