Make changes to shard_map to prepare for setting `varying_axes_in_types` to True.
The main changes here are:
* Don't take the `_efficient_transpose_rewrite` transformation path anymore. In other words, `RewriteTrace` and all the rewriting machinery is dead.
* Wherever internally we were setting `check_rep=False` explicitly like `_prim_applier`, `_match`, `_unmatch`, `_shard_map_partial_eval`, `_shard_map_partial_eval_custom` (for remat), don't do that anymore. Instead set `check_rep` to the `check_rep` value so that it can be True if the user hasn't passed `check_rep=False`.
* Introduce an internal `_check_rep` context manager and set it wherever `extend_axis_env_nd` is used so that if `check_rep=False` on `shard_map`, JAX will set `vma` in `ShapedArray` to empty `frozenset`.
* Because of point (2), if `check_rep=True`, we can't set `in_specs` and `out_specs` of shmap internally to all manual axes of the mesh on the 0th dim. It needs to be whatever the argument was varying on.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 745276474