[JAX] Fix a small bug if shardings is tuple.
# Details
`jax.tree.map` requests all its arguments to have the same data type.
From ```[None] * len(tensorstore_specs) if global_shapes is None else global_shapes```,
The data type is already decided to be a list. So if we pass `sharding` or `tspecs` as a tuple, it will fail.
Here we add an explicit conversion to a list for sharding and tspecs.
PiperOrigin-RevId: 707576866