flax
06323f27 - Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code.

Commit
2 years ago
Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code. Uses jax.local_devices() as argument to jax.device_put_replicated and jax.device_put_sharded, removing the custom logic for older JAX versions. PiperOrigin-RevId: 506573580
Author
Committer
Parents
Loading