Update shard_map(jit) to properly set manual_axes on in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manaual} and then this gets passed to ShardToFull (manual is already considered a full sharding).
PiperOrigin-RevId: 655719254