jax
8dc3715a - Fall back to the pure JAX implementation of tridiagonal_solve for m<=2 on GPU:

Commit
4 days ago
Fall back to the pure JAX implementation of tridiagonal_solve for m<=2 on GPU: Fixes https://github.com/jax-ml/jax/issues/32487 PiperOrigin-RevId: 892421765
Author
Parents
Loading