Fix jax.device_put so it doesn't use tree_map for _check_sharding.
This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays.
PiperOrigin-RevId: 570189648