[JAX] Move py_client_gpu into JAX.
This callback functionality is only used by JAX and shipped as part of its CUDA and ROCM GPU plugins. Move it into JAX, as part of a wider move of xla/python pieces that belong to JAX into JAX.
PiperOrigin-RevId: 738426489