fix a 'store occupied' error in jax_debug_nans
This code snippet could cause a 'store occupied' error:
@jit
def f(x):
return x + np.nan
FLAGS.jax_debug_nans = True
f(1)
The reason is that in xla._xla_call_impl we would run a
linear_util.WrappedFun twice, first via xla._xla_callable and then again
directly (i.e. in op-by-op) if we got a nan on the output. Things would
work fine if the second execution also raised a nan error, since then
the WrappedFun wouldn't complete execution, but if the second execution
does not raise an error (as in the above case, because `1 + np.nan`
doesn't involve any jax primitive executions) then we'd end up with a
StoreOccupied error from running the WrappedFun twice.
The fix is just to intentionally allow re-running the WrappedFun, since
the whole point of jax_debug_nans is to re-run functions that in normal
circumstances we would only want to execute exactly once.