[remat] Change remat lowering to XLA::Conditional (#2391)
* [remat] Change remat lowering to XLA::Conditional
`jax.remat` creates rematerializing passes that don't have data dependencies on
the actual loss-computing forward pass. This means that the XLA scheduler was
free to schedule the remat forward pass before the loss-computing pass,
defeating the goal of saving accelerator memory with `jax.remat`.
In practice, the scheduler sometimes did exactly that on my workloads.
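This change expresses the lowering of remat_call(f) as:
  Conditional(true, inputs, f, inputs, dummy_f)
The predicate is constant true, so only f ever runs; dummy_f is the
never-taken false branch that the Conditional signature requires.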
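
As a rough user-level analogy (not the lowering code touched by this change),
the same shape can be written with `jax.lax.cond`. The names `f`, `dummy_f`,
and `remat_call_sketch` are hypothetical and chosen only for illustration:

```python
import jax.numpy as jnp
from jax import lax


def f(x):
    # Stand-in for the computation wrapped by remat (hypothetical).
    return jnp.sin(x) * 2.0


def dummy_f(x):
    # Never-taken branch; it only has to match f's output shape and dtype.
    return jnp.zeros_like(x)


def remat_call_sketch(x):
    # Analogue of Conditional(true, inputs, f, inputs, dummy_f):
    # the predicate is always true, so f is the branch that runs.
    return lax.cond(True, f, dummy_f, x)


x = jnp.arange(3.0)
out = remat_call_sketch(x)
```

The result matches calling `f(x)` directly; the conditional only changes how
the computation is packaged for XLA, not what it computes.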
In the common case of `jax.grad(jax.remat(f))`, the lowered remat_call
contains both the forward and the backward computation; that is, the
incoming cotangents are part of the args.
Additionally, Conditional (AFAIK) is un-inlineable in the sense that it
doesn't execute until all its inputs (e.g. cotangents!) are available.
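
For reference, a minimal sketch of the `jax.grad(jax.remat(f))` pattern
discussed above. The function `f` and the input `x` are made up for
illustration; the point is only that wrapping `f` in `jax.remat` changes
memory and scheduling behaviour, not the gradient values:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical block whose intermediates (cos, sin) we would rather
    # recompute during the backward pass than keep in memory.
    return jnp.sum(jnp.sin(jnp.cos(x)))


x = jnp.linspace(0.0, 1.0, 4)

# Plain gradient: intermediates from the forward pass are saved.
grad_plain = jax.grad(f)(x)

# Rematerialized gradient: the forward pass is recomputed inside the
# backward pass, which is where the incoming cotangents become inputs.
grad_remat = jax.grad(jax.remat(f))(x)
```

Both gradients should agree numerically; inspecting
`jax.make_jaxpr(jax.grad(jax.remat(f)))(x)` is one way to see how the
rematerialized call is represented before lowering.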
Downsides:
- AFAICT, we can no longer interleave computation inside and outside the
  rematerialized block.
- Potentially lower performance, although I do not observe this in my tests.
* provide no replication info for subcomputation params