Replace deprecated `jax.tree_*` functions with `jax.tree.*`
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.
PiperOrigin-RevId: 634095385