flax
7ed82a49 - Plumb through the prevent_cse kwarg for jax remat.

Commit
4 years ago
Plumb through the prevent_cse kwarg for jax remat. Jax recently exposed an option to turn off the common-subexpression elimination foil device in remat. This is needed for e.g. scan(remat(...)). This PR plumbs that kwarg through the Flax lifted remat/checkpoint transforms.
Author
Committer
Parents
Loading