Defer to unrecognized types in arithmetic (#1942)
This is useful for building higher level array libraries around JAX, because it
makes it possible to override operations like `jax_array + other`.
I think I covered all the array types that JAX should be able to handle:
- Python builtin numbers int, float and complex
- NumPy scalars
- NumPy arrays
- JAX array types and tracers
Did I miss anything? Maybe bfloat16 scalars?