Fix bug in all_leaves when is_leaf is specified
This issue occurs when some of the leaves have custom `__eq__` methods defined on them, which either result in errors when compared to some other types (see http://cl/753579906), or result in return values that cannot have their truthiness evaluated, e.g.:
```
import jax.tree_util as jtu
import numpy as np
jtu.all_leaves(
[[np.asarray([1, 2])]],
is_leaf=lambda x: jtu.all_leaves([x]),
)
```
```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
This fix avoids equality issues by using the `is` operator instead of `==`, and introduces tests for the case where `is_leaf` is provided.
PiperOrigin-RevId: 753684035