jax
784a37c2 - Require arguments to Primitive.bind() to pass typeof().

Commit
82 days ago
Require arguments to Primitive.bind() to pass typeof(). We apparently used to enforce this via valid_jaxtype, but that code path was commented out. However as far as I can tell there's nothing stopping us from enabling this again, and that would let us maintain the invariant that all arguments to primitive binds are well-typed. This change is also in preparation for removing other calls to typeof. Also remove a test in lax_test that seems to be identical to one in api_test. PiperOrigin-RevId: 880821716
Author
Parents
Loading