flax
69b9d44f - Avoid a NumPy bug, triggered by the JAX change

Commit
5 years ago
Avoid a NumPy bug, triggered by the JAX change https://github.com/google/jax/pull/3821. The idea of the JAX change is in part that DeviceArray.__iter__ should return DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is performance: it avoids a host sync. A secondary motivation is type consistency. However, that caused this line of Flax example code to trigger a NumPy bug, discussed in this thread: https://github.com/google/jax/issues/620#issuecomment-484344945 Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of length 10 or less_ causes NumPy to interperet i as a non-array sequence (e.g. a tuple) rather than as an array, leading to an error like "IndexError: too many indices for array". The workaround employed here is to write x[i, ...] instead of x[i], which bypasses the NumPy bug. PiperOrigin-RevId: 345160314
Author
Committer
Parents
Loading