jax
6e23c14f - jax.debug.callback now passes arguments as jax.Arrays

Commit
1 year ago
jax.debug.callback now passes arguments as jax.Arrays Prior to this change the behavior in eager and under jax.jit was inconsistent >>> (lambda *args: jax.debug.callback(print, *args))([42]) [42] >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42]) [array(42, dtype=int32)] It was also inconsistent with other callback APIs, which cast the arguments to jax.Arrays. Closes #20627. PiperOrigin-RevId: 626461904
Author
Committer
Parents
Loading