flax
bbbbdcc9 - [JAX] Replace uses of `isinstance(x, jax.numpy.ndarray)` with `isinstance(x, (numpy.ndarray, jax.numpy.ndarray))` where they were leading to test failures.

Commit
4 years ago
[JAX] Replace uses of `isinstance(x, jax.numpy.ndarray)` with `isinstance(x, (numpy.ndarray, jax.numpy.ndarray))` where they were leading to test failures. An upcoming change to JAX will make `isinstance(x, jax.numpy.ndarray)` return true if and only if `x` is an instance of a JAX array type. Previously `isinstance(x, jax.numpy.ndarray)` also returned true for classic NumPy's `numpy.ndarray` objects as well. After the upcoming change, it will return false for `numpy.ndarray` objects. This change updates users of JAX who were depending on the current behavior of the `isinstance` check to instead explicitly check for `numpy.ndarray` instances as well. These changes should have no effect on using `jax.numpy.ndarray` as a type annotation. That does little and never has, although it is possible that may change in the future. This change is strictly about what `jax.numpy.ndarray` means to runtime `isinstance` checks. PiperOrigin-RevId: 397124334
Author
Committer
Parents
Loading