flax
83a5b3ba
- Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
3 years ago
Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code. Uses jax.local_devices() as argument to jax.device_put_replicated and jax.device_put_sharded, removing the custom logic for older JAX versions. PiperOrigin-RevId: 506667707
References
#2840 - Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code.
Author
marcvanzee
Committer
a-googler
Parents
1b5b504b
Loading