jax
e5bbf3dc - [jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.

Commit
1 year ago
[jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf. Consider the use case when we call_tf a restored saved model that includes parameters (hence functions closing over tf.Variable), and then we jax2tf.convert it with native serialization, under tf.function (or for saving to saved model). The lowering for call_tf in presence of functions with captured inputs requires looking up the tf.Variable and reading its value. This fails with an error that `v.numpy()` is not allowd in graph mode. The fix is to use `tf.init_scope()` to lift out of graph building mode, so that we can read the value of the variables.
Author
Committer
Parents
Loading