Avoid passing non-boolean mask to `where` argument of `jax.numpy` reductions. Non-boolean mask inputs have been deprecated for several releases, and will result in an error starting in JAX v0.8.0.
The reason for this restriction is that implicit conversion of non-boolean inputs may lead to confusion: e.g. the user may expect integer inputs to be treated as indices, or float inputs to be treated as weights. But explicitly requiring boolean-typed inputs (as NumPy does), we remove these ambiguities.
PiperOrigin-RevId: 804929133