flax
7551bfc3 - Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated.

Commit
4 days ago
Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated. PiperOrigin-RevId: 857821732
Author
Committer
Parents
Loading