Commit
1 year ago
[nnx] fix vmap ## Changes * passes `state_axes` as a list to `vmapped_fn` to avoid a JAX compare error on the dictionary PiperOrigin-RevId: 643267090
Author
Cristian Garcia
Committer
Parents
Loading