flax
06323f27
- Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code.
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Commit
View On
GitHub
Commit
2 years ago
Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code. Uses jax.local_devices() as argument to jax.device_put_replicated and jax.device_put_sharded, removing the custom logic for older JAX versions. PiperOrigin-RevId: 506573580
Author
marcvanzee
Committer
a-googler
Parents
dfb55c4d
Loading