flax
Use `jax.tree_util.tree_map` instead of deprecated `jax.tree_map`.
#3714
Merged

Loading