Avoid passing `concrete` argument to `jax.remat`
This argument has had no effect since JAX v0.3.17, aside from raising `NotImplementedError` if it is set to `True`. It will be deprecated in JAX v0.8.2 and eventually removed (https://github.com/jax-ml/jax/pull/33674).
Flax should probably deprecate this argument from its own `remat` wrappers, but I'll leave that up to the team.
#jax-fixit
PiperOrigin-RevId: 840353856