Fix for change to `jax.Array.__contains__`
In https://github.com/jax-ml/jax/pull/36655 JAX updates the implementation of `jax.Array.__contains__`, which makes `val in array` more efficient.
As part of this, JAX will become more strict in what types it allows; e.g. `string in array` will raise an error rather than returning `False`.
This update future-proofs code in preparation for this change.
PiperOrigin-RevId: 897791321