flax
Use `jnp.stack` instead of `np.stack` in `flax.training.common_utils.stack_forest`
#4991
Merged

Loading