jax
fe1cef42 - Don't calculate `scan_xs` and `scan_ys` when batch size is greater than len(input) i.e. num_batches == 0. Fixes https://github.com/jax-ml/jax/issues/29867

Commit
217 days ago
Don't calculate `scan_xs` and `scan_ys` when batch size is greater than len(input) i.e. num_batches == 0. Fixes https://github.com/jax-ml/jax/issues/29867 PiperOrigin-RevId: 777767223
Author
Parents
Loading