jax
f1ae2166 - Added argument check to all primitives. (#3197)

Commit
5 years ago
Added argument check to all primitives. (#3197) * Added argument check to all primitives. The issue that inspired this is that `lax.tie_in` is easy to misuse if the first argument is not a JAX type, then it silently disappears. This means that `lax.tie_in((x, x), const)` is the same as `const` even though `x` is a tracer. This error would be caught previously if core.skip_checks == False because then `bind` checks its arguments. I have essentially added an unconditional argument check to `bind`. In case this is considered too inefficient, we can add argument checking to individual primivites, e.g., tie_in. For most primitives if a non-JAX array is passed, the `impl` rule would fire and `numpy` would report the error somehow, perhaps. * Merged find_top_trace with check_args This was previously merged as #2948 but reverted awaiting the fixes in some user code.
Author
Parents
Loading