jax
904b7486 - [better_errors] Continue adding debug info to Jaxprs (step 3)

Commit
1 year ago
[better_errors] Continue adding debug info to Jaxprs (step 3) This follows after #26078, and #26313, adding `debug_info` to more calls to `lu.wrap_init`. As part of this I have changed the primitives `custom_vjp_call_jaxpr` and `custom_lin` to take the `bwd` parameter as a `lu.WrappedFun`, which carries debug info. Previously, this was a `Callable`, but in almost all cases if was really ` lu.WrappedFun.call_wrapped`.
Author
Committer
Parents
Loading