jax
2653c5b5 - Make sure the `carry` is varying across the same axes as `sorted_arr` when `scan` is called in `searchsorted`. Fixes https://github.com/jax-ml/jax/issues/29881

Loading