jax
a5644edb - Defer to unrecognized types in arithmetic (#1942)

Commit
6 years ago
Defer to unrecognized types in arithmetic (#1942) This is useful for building higher level array libraries around JAX, because it makes it possible to override operations like `jax_array + other`. I think I covered all the array types that JAX should be able to handle: - Python builtin numbers int, float and complex - NumPy scalars - NumPy arrays - JAX array types and tracers Did I miss anything? Maybe bfloat16 scalars?
Author
Parents
Loading