Update tridiagonal solve kernels on GPU to properly use the FFI.
This fixes https://github.com/jax-ml/jax/issues/28544 by using the batched algorithms directly when possible. It also adds complex dtype and batch partitioning support to tridiagonal solves on GPU.
PiperOrigin-RevId: 758129745