jax
b59c0e48 - Initial commit for `smap` i.e. shard_map1D. The signature is `smap(f, in_axes, out_axes, axis_name)`. This change does NOT make the API public.

Commit
328 days ago
Initial commit for `smap` i.e. shard_map1D. The signature is `smap(f, in_axes, out_axes, axis_name)`. This change does NOT make the API public. The API semantics are as follows: * `smap` only allows going into `Manual` mode one mesh axes at a time via the `axis_name` argument. * mesh needs to be present in the context via `use_mesh` or `set_mesh`. * If in_axes or out_axes contains `None`, it means that the input(s) is **replicated**. This is similar to `vmap` where `None` means unmapped input. * If the context mesh is in full explicit mode, `in_axes` can be inferred from the arguments. But how do we tell `smap` to do that? We **can't** use `None` because `None` means replicated in `smap`. So we introduce a singleton called `Infer` which when passed to `smap`, will tell it to infer the in_axes (in_specs) from the arguments! For example: `smap(f, in_axes=Infer, out_axes=0, axis_name='x')`. You always have the option of specifying `in_axes` and not infer even in full explicit mode :) PiperOrigin-RevId: 753695446
Author
Parents
Loading