jax
f7faaa81
- Fix lax.switch where unmapped arg was not being broadcasted on the correct `axis_data.explicit_mesh_axis` under vmap.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
264 days ago
Fix lax.switch where unmapped arg was not being broadcasted on the correct `axis_data.explicit_mesh_axis` under vmap. Fixes: https://github.com/jax-ml/jax/issues/29637 PiperOrigin-RevId: 774795809
References
#29652 - Fix lax.switch where unmapped arg was not being broadcasted on the correct `axis_data.explicit_mesh_axis` under vmap.
Author
yashk2810
Committer
Google-ML-Automation
Parents
30ba5c9c
Loading