jax
c26c77d2 - fix a 'store occupied' error in jax_debug_nans

Commit
5 years ago
fix a 'store occupied' error in jax_debug_nans This code snippet could cause a 'store occupied' error: @jit def f(x): return x + np.nan FLAGS.jax_debug_nans = True f(1) The reason is that in xla._xla_call_impl we would run a linear_util.WrappedFun twice, first via xla._xla_callable and then again directly (i.e. in op-by-op) if we got a nan on the output. Things would work fine if the second execution also raised a nan error, since then the WrappedFun wouldn't complete execution, but if the second execution does not raise an error (as in the above case, because `1 + np.nan` doesn't involve any jax primitive executions) then we'd end up with a StoreOccupied error from running the WrappedFun twice. The fix is just to intentionally allow re-running the WrappedFun, since the whole point of jax_debug_nans is to re-run functions that in normal circumstances we would only want to execute exactly once.
Author
Parents
Loading