Commit
131 days ago
fix RNN Implements Mapping for `StateAxes` and uses `StateAxes` in place of `dict` to fix RNN, this avoids some JAX pytree errors when scanning attributes for data in `nnx.Pytree`. PiperOrigin-RevId: 800653246
Author
Cristian Garcia
Committer
Parents
Loading