jax
f41f4a87 - Ensure ShapedArray.shape is always a tuple of builtins integers

Commit
6 years ago
Ensure ShapedArray.shape is always a tuple of builtins integers Currently, it can sometimes include elements of type int64, e.g., In [1]: import jax.numpy as jnp In [2]: x = jnp.arange(3) + 1 In [3]: x.shape # looks fine at first glance Out[3]: (3,) In [4]: type(x.shape[0]) # yikes! Out[4]: numpy.int64 This confirms my hypothesis that NumPy's scalar types are the root of all evil.
Author
Committer
Parents
Loading