jax
f8aa563d - make jax.numpy.array(3) give 0D array, not scalar

Commit
6 years ago
make jax.numpy.array(3) give 0D array, not scalar the mechanism is to use lax.reshape (which was already there) and avoid the optimization that skipped actually calling reshape_p.bind fixes #121
Author
Parents
Loading