flax
Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated.
#5189
Merged

Commits
  • Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated.
    a-googler committed 74 days ago
Loading