Add a simple form of partial evaluation for while_loop. (#2497)
The issue I wanted to fix is that running grad(while_loop) failed
with a cryptic assertion failure (that all primals are known after
linearization, in ad.py:linearize). I could not figure out how to
detect, before that assertion fires, that we are doing reverse-mode
AD through a while_loop. So I implemented a simple form of partial
evaluation that lets the primals be known after linearization, so
that the code proceeds and can then fail gracefully when trying to
transpose the while.
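A minimal sketch of the behavior described above, assuming a recent
JAX release. pow_loop is a hypothetical example function, not part of
this change. Forward-mode differentiation through lax.while_loop is
expected to work, while jax.grad is expected to raise an error instead
of returning a gradient. The exact exception type and message are not
asserted, since they may differ between versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pow_loop(x):
    """Hypothetical example: compute x**3 with a while_loop."""

    def cond(carry):
        i, _ = carry
        return i < 3

    def body(carry):
        i, acc = carry
        # The counter i does not depend on x; the accumulator does.
        return i + 1, acc * x

    _, acc = lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(1.0)))
    return acc


x = jnp.float32(2.0)

# Forward mode: jvp returns the primal output and its tangent.
primal, tangent = jax.jvp(pow_loop, (x,), (jnp.float32(1.0),))

# Reverse mode: transposing the while is not supported, so an error is
# expected here. Any exception type is accepted in this sketch.
try:
    jax.grad(pow_loop)(x)
    grad_failed = False
except Exception as err:
    grad_failed = True
    print("grad(while_loop) raised:", type(err).__name__)
```

In this sketch the loop counter plays the role of a known output and
the accumulator tangent the role of an unknown one.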
This is not a proper implementation of partial evaluation. The known
outputs are computed early, properly. But the unknown outputs are
computed by a *whole* computation of the while_loop, including the
known parts.
Fixes issue: #2129