jax
61703690 - Add a direct HLO lowering of remat_p that doesn't call eval_jaxpr.

Commit
1 year ago
Add a direct HLO lowering of remat_p that doesn't call eval_jaxpr. This turns out to be faster, not least because we don't need to use the tracing machinery again. PiperOrigin-RevId: 647462045
Author
Committer
Parents
Loading