jax
7d716b83 - Add a simple form of partial evaluation for while_loop. (#2497)

Commit
5 years ago
Add a simple form of partial evaluation for while_loop. (#2497) The issue that I wanted to fix was that when running grad(while_loop), the error was a cryptic assertion failure (that all primals are known after linearization, in ad.py:linearize). I could not figure out how to detect before that assertion that we are doing a reverse AD for while_loop. So, I implemented a simple form of partial evaluation, to allow the primals after linearization to be known, so that the code proceeds and can then fail gracefully when trying to transpose the while. This is not a proper implementation of partial evaluation. The known outputs are computed early, properly. But the unknown outputs are computed by a *whole* computation of, including the known parts. Fixes issue: #2129
Author
Parents
Loading