Initial commit for `smap` i.e. shard_map1D. The signature is `smap(f, in_axes, out_axes, axis_name)`. This change does NOT make the API public.
The API semantics are as follows:
* `smap` only allows going into `Manual` mode one mesh axes at a time via the `axis_name` argument.
* mesh needs to be present in the context via `use_mesh` or `set_mesh`.
* If in_axes or out_axes contains `None`, it means that the input(s) is **replicated**. This is similar to `vmap` where `None` means unmapped input.
* If the context mesh is in full explicit mode, `in_axes` can be inferred from the arguments. But how do we tell `smap` to do that? We **can't** use `None` because `None` means replicated in `smap`.
So we introduce a singleton called `Infer` which when passed to `smap`, will tell it to infer the in_axes (in_specs) from the arguments! For example: `smap(f, in_axes=Infer, out_axes=0, axis_name='x')`.
You always have the option of specifying `in_axes` and not infer even in full explicit mode :)
PiperOrigin-RevId: 753695446