jax
f499352c - Add `tree_util.is_tree_node` to allow checking whether a type is a registered JAX pytree.

Commit
43 days ago
Add `tree_util.is_tree_node` to allow checking whether a type is a registered JAX pytree. This CL exposes `is_node` function of the C++ pytree registry so that user won't need to rely on querying the private python registry. Eventually we want to migrate people away and delete the python registry. PiperOrigin-RevId: 896061794
Author
Parents
Loading