jax
32bf19ac - Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs.

Commit
1 year ago
Add a temporary fix for spurious debug_nans errors when round-tripping jaxprs. debug_nans is sometimes disabled locally at the traceable level by ops that work with nans internally, like jnp.var. But we don't capture this local change-of-context in the jaxpr. The right thing to do is to add contexts to our jaxpr representation so that we can capture these local context modifications. In the meantime, disabling the checks when we round-trip prevents those ops producing spurious errors. PiperOrigin-RevId: 691494516
Author
Parents
Loading