flax
Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated.
#5189
Merged

Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated. #5189

copybara-service merged 1 commit into main from test_856253490
copybara-service
copybara-service copybara-service force pushed from e296680c to ed741903 10 days ago
copybara-service copybara-service changed the title Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Shard split rngs in lifted vmap if spmd_axis_name is given and applied to vmapped axis. Without this, jax.vmap gives an error: `ValueError: Mapped away dimension of inputs passed to vmap should be sharded the same. Got inconsistent axis specs: None vs batch` due to split_rngs being replicated. 10 days ago
copybara-service copybara-service force pushed from ed741903 to a955c5f9 7 days ago
jossb-iso Shard split rngs in lifted vmap if spmd_axis_name is given and applie…
7551bfc3
copybara-service copybara-service force pushed from a955c5f9 to 7551bfc3 7 days ago
copybara-service copybara-service merged 7551bfc3 into main 7 days ago
copybara-service copybara-service deleted the test_856253490 branch 7 days ago

Login to write a write a comment.

Login via GitHub

Reviewers
No reviews
Assignees
No one assigned
Labels
Milestone