Remove calls to deprecated function jax.lax.tie_in
This has been a no-op since jax v0.2.0, and passes the second argument through unchanged. `tie_in` will be deprecated as of jax v0.4.24; see https://github.com/google/jax/pull/19413
PiperOrigin-RevId: 600515954