jax
f7faaa81 - Fix lax.switch where unmapped arg was not being broadcasted on the correct `axis_data.explicit_mesh_axis` under vmap.

Commit
264 days ago
Fix lax.switch where unmapped arg was not being broadcasted on the correct `axis_data.explicit_mesh_axis` under vmap. Fixes: https://github.com/jax-ml/jax/issues/29637 PiperOrigin-RevId: 774795809
Author
Parents
Loading