transformers
e6f221c8 - [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

Commit
3 years ago
[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361) * [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
Parents
Loading