fe1cef42 - Don't calculate `scan_xs` and `scan_ys` when batch size is greater than len(input) i.e. num_batches == 0. Fixes https://github.com/jax-ml/jax/issues/29867
Don't calculate `scan_xs` and `scan_ys` when batch size is greater than len(input) i.e. num_batches == 0. Fixes https://github.com/jax-ml/jax/issues/29867
PiperOrigin-RevId: 777767223