jax
1629c6c7
- Make `jax.jit` work with vmap(..., spmd_axis_name) when there is no mesh context manager.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
2 years ago
Make `jax.jit` work with vmap(..., spmd_axis_name) when there is no mesh context manager. This will only work if the input Array's sharding is a NamedSharding Fixes https://github.com/google/jax/issues/15886 PiperOrigin-RevId: 529758233
Author
yashk2810
Committer
a-googler
Parents
d992080b
Loading