flax
934d53cd - Add tree-mode support to nnx.scan

Commit
83 days ago
Add tree-mode support to nnx.scan Adds `TreeScanFn` dataclass and `graph` parameter to `nnx.scan`, enabling tree-mode operation when `graph=False`. Tree-mode enforces default `in_axes=(Carry, 0)` and `out_axes=(Carry, 0)`, checks carry Variable identity preservation across iterations, and performs alias checks on scan outputs. Also parametrizes existing scan tests and adds new tree-mode specific tests. PiperOrigin-RevId: 873044806
Author
Cristian Garcia
Committer
Parents
Loading