flax
4887f7dc - Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0.

Commit
202 days ago
Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0. The reason for this restriction is that implicit conversion of non-boolean inputs may lead to confusion: e.g. the user may expect integer inputs to be treated as indices, or float inputs to be treated as weights. But explicitly requiring boolean-typed inputs (as NumPy does), we remove these ambiguities. PiperOrigin-RevId: 804929133
Author
Jake VanderPlas
Committer
Parents
Loading