Plumb through the prevent_cse kwarg for jax remat.
Jax recently exposed an option to turn off the common-subexpression
elimination foil device in remat. This is needed for e.g. scan(remat(...)).
This PR plumbs that kwarg through the Flax lifted remat/checkpoint
transforms.