flax
9906f212 - Increment jax requirement to fix incompatibility

Commit
4 years ago
Increment jax requirement to fix incompatibility When using `jax==0.2.20` with `flax==0.3.5`, `flax.jax_utils.replicate` fails with: `AttributeError: 'jaxlib.tpu_client_extension.PyTpuBuffer' object has no attribute 'dtype'` Incrementing minimum `jax` version to `jax==0.2.21` fixes this problem.
Author
Parents
Loading