If `out_shardings=None` then it means that XLA will choose the sharding for that output. But before this change under a `use_mesh(mesh)` context, we were converting `None` to replicated.
This change fixes that and let's XLA choose the sharding so that the behavior with and without the `use_mesh` context is the same.
But due to backwards compatibility and historical reasons, we can only apply this to `jax.jit` and not `with mesh: pjit(f, out_shardings=None)` calls because the pjit considers None as replicated under a mesh context (:sad face:).
PiperOrigin-RevId: 778920842