[JAX] Require all positional arguments passed to bind to be JAX values.
This required two fixes:
* don't pass functions as position arguments, instead pass them as a subfuns kwarg
* change sparsify to tree-flatten its data structures.
PiperOrigin-RevId: 881713266