Increment jax requirement to fix incompatibility
When using `jax==0.2.20` with `flax==0.3.5`, `flax.jax_utils.replicate` fails with:
`AttributeError: 'jaxlib.tpu_client_extension.PyTpuBuffer' object has no attribute 'dtype'`
Incrementing minimum `jax` version to `jax==0.2.21` fixes this problem.