[jax2tf] Fixes a bad interaction between jax2tf.convert, TF, and call_tf.
Consider the use case when we call_tf a restored saved model that
includes parameters (hence functions closing over tf.Variable), and then
we jax2tf.convert it with native serialization, under tf.function (or
for saving to saved model).
The lowering for call_tf in presence of functions with captured inputs
requires looking up the tf.Variable and reading its value. This fails
with an error that `v.numpy()` is not allowd in graph mode. The fix
is to use `tf.init_scope()` to lift out of graph building mode, so that
we can read the value of the variables.