jax
c93f283f - [shmap] `shard_map` and `axis_index_groups` for standard collectives

Commit
1 year ago
[shmap] `shard_map` and `axis_index_groups` for standard collectives These collectives have unreplicated outputs and the exception can simply be removed. Unit tests are added for `lax.all_gather`, `lax.all_to_all`, `lax.psum_scatter`. Fixes #19709. PiperOrigin-RevId: 607104947
Author
Committer
Parents
Loading