jax
4768e003 - Fix bug in all_leaves when is_leaf is specified

Commit
274 days ago
Fix bug in all_leaves when is_leaf is specified This issue occurs when some of the leaves have custom `__eq__` methods defined on them, which either result in errors when compared to some other types (see http://cl/753579906), or result in return values that cannot have their truthiness evaluated, e.g.: ``` import jax.tree_util as jtu import numpy as np jtu.all_leaves( [[np.asarray([1, 2])]], is_leaf=lambda x: jtu.all_leaves([x]), ) ``` ``` ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() ``` This fix avoids equality issues by using the `is` operator instead of `==`, and introduces tests for the case where `is_leaf` is provided. PiperOrigin-RevId: 753684035
Author
Luke Tsekouras
Parents
Loading