jax
620bf430 - [remat] Change remat lowering to XLA::Conditional (#2391)

Commit
6 years ago
[remat] Change remat lowering to XLA::Conditional (#2391) * [remat] Change remat lowering to XLA::Conditional `jax.remat` creates rematerializing passes that don't have data dependencies on the actual loss-computing forward pass. This means that the XLA scheduler was free to schedule the remat forward pass before the loss-computing pass, defeating the goal of saving accelerator memory with `jax.remat`. In practice, it sometimes did for my workloads. This change expresses the lowering of remat_call(f) as: Conditional(true, inputs, f, inputs, dummy_f). In the common case of `jax.grad(jax.remat(f))`, the content of the lowered remat_call are both the forwards & backwards; that is, the incoming cotangents are part of the args. Additionally, Conditional (AFAIK) is un-inlineable in the sense that it doesn't execute until all its inputs (e.g. cotangents!) are available. Downsides: - AFAICT, we can no longer interleave computation in/outside the rematerialized block. - Potentially, lower performance. I do not observe this in my tests. * provide no replication info for subcomputation params
Author
Parents
Loading