jax
c43f0a0f - [JAX] Require all positional arguments passed to bind to be JAX values.

Commit
78 days ago
[JAX] Require all positional arguments passed to bind to be JAX values. This required two fixes: * don't pass functions as position arguments, instead pass them as a subfuns kwarg * change sparsify to tree-flatten its data structures. PiperOrigin-RevId: 881713266
Author
Parents
Loading