8f2d72eb - Simplify handling of non-linear equations in backward_pass and fix remat (#3162)

Previously, `backward_pass` had been generalized to handle non-linear computation in the body, but it could easily get tricked into doing unnecessary work, only to throw it away later. Additionally, it treated any call primitive embedded inside remat like remat itself, which is obviously wrong. This patch fixes both of those issues and simplifies a bunch of the code at the same time. `backward_pass` now has an invariant that it only deals with jaxprs containing linear equations alone, and becomes a simple transposing interpreter again.

**Background on JVP vs linearization**

Ok, so why does this change actually fix the problem? It is important to understand that the JVP and linearization transforms are actually two different things, even though we often identify them with each other. Both take in a function of type `a -> b`, but their ranges are different! JVP returns a function of type `(a, T a) -> (b, T b)`, while linearization returns `a -> (b, T a --o T b)`. Note that the second type carries more information, because we get a guarantee that (1) `b` does not depend on `T a` and (2) the dependence of `T b` on `T a` is linear.

The reason why we usually treat them as equivalent is that they can be shown to be "isomorphic". If we take the output of linearization, we can make it into a JVP-like function using the following combinator (note that `linearize f` has to be applied to the primal input `a` first):

```haskell
jvp f = \a ta -> let (b, lf) = linearize f a in (b, lf ta)
```

More importantly for JAX, which doesn't have a linearization interpreter, if we assume (1) and (2), linearization can be recovered in terms of jvp as well:

```haskell
linearize f = \a -> let fjvp = jvp f in partial_eval fjvp (Known a) Unknown
```

That is, if we have a mathematically correct JVP, then linearization is simply partial evaluation with all primal values marked as known and all tangents treated as yet-unknown values.

One important performance consideration is that for forward-mode AD we really want to use the JVP formulation, which can interleave the computation of primals and tangents, instead of sequencing them and increasing the memory cost. On the other hand, transposition (necessary for VJPs!) can only be applied to linear functions, and so it can't possibly work on the output of JVP. It can really only be applied to the second output of the linearization transform. Hence, we really care about both, but can we avoid having two very similar implementations of (approximately) the same thing? It seems that the answer is yes, because of the equivalence outlined above!

**If all this is so nice, then what's the problem?**

The problem is, of course, remat. Partial eval is able to thread the known/unknown information correctly through regular call primitives, but mind you, remat is no regular call primitive! Once we enter remat, we are no longer interested in treating _anything_ as a known value. After all, our goal here is to record an accurate trace of everything that has happened in the body of a remat, including the primal (known!) computation. This, however, presents a challenge for implementing linearization in terms of JVP, because inside the body of a remat we break the assumption that known/unknown corresponds to the primal/tangent distinction. Its body, instead of representing the second output of linearization, now simply contains the traced JVP code...
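As an aside, the equivalence described above can be checked through the public JAX API. Here is a minimal sketch (the function `f` and the input values are arbitrary illustrations, not part of this patch):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x, tx = 1.0, 2.0

# JVP formulation: (a, T a) -> (b, T b); primal and tangent are computed together.
y1, ty1 = jax.jvp(f, (x,), (tx,))

# Linearization: a -> (b, T a --o T b); in JAX this is implemented as JVP
# followed by partial evaluation with the primal inputs marked as known.
y2, f_lin = jax.linearize(f, x)
ty2 = f_lin(tx)

assert jnp.allclose(y1, y2) and jnp.allclose(ty1, ty2)
```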
One way to fix it would be to implement a proper linearization pass that tracks the distinction between primal and tangent information while still allowing primal code to be staged out. @mattjj and I have even started hacking together an implementation of that. I've been trying to convince @mattjj that there is no other way to go about it, but I couldn't really convince him that this is the case. Then, once I wanted to write a semi-formal proof, I could no longer even convince myself!

Turns out that there is an alternative solution! What this patch does is stop requiring the output of the `linearize` function (defined as JVP + partial eval, as discussed above) to be a good linearization. It still is one if you don't use remat in your code, but it breaks miserably once you do. However, as long as all the complications are contained solely in the `call_jaxpr` embedded inside a remat, we still have a chance to fix them! This is because the transposition interpreter never reaches into those bodies directly, but rather asks the call primitive to transpose itself.

Now, how do you transpose remat? We can't just reuse the code used for regular call primitives (which is what happens now, by the way), because unlike for them, the `call_jaxpr` doesn't represent a linear function! But it's not completely useless either: it contains the traced JVP code. So, how do we get from there to a linear function? Partial eval! And if you think about it, this is exactly what we wanted: we end up evaluating all the primal code in the body once again, while staging out only the tangent computation, to be passed into the transposing interpreter. Fin.
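For intuition, here is a minimal runnable sketch of reverse-mode AD through remat using the public API (the function and the input are arbitrary illustrations; the comments summarize the transposition story above):

```python
import jax
import jax.numpy as jnp

@jax.remat  # rematerialize: re-run the primal computation on the backward pass
def f(x):
    return jnp.sin(jnp.sin(x))

# Transposing the remat call partially evaluates its traced JVP body
# (re-running the primal sines) and transposes only the staged-out
# linear tangent computation.
g = jax.grad(f)(2.0)
assert jnp.allclose(g, jnp.cos(jnp.sin(2.0)) * jnp.cos(2.0))
```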