jax
fe359a15 - [jax] bound traceback of pallas_call to within the kernel.

Commit
90 days ago
[jax] bound traceback of pallas_call to within the kernel. Jax's lowering cache for pallas call might reuse a previous lowering result for a different jit resulting in a mlir module containing the wrong debug info. This does not cause correctness issues, but changes the hlo fingerprints which may cause mismatches when there are multiple parallel AOT compilations submitted to a thread pool. By bounding the stack trace to under the call site, cached stacktraces should match. PiperOrigin-RevId: 857343868
Author
Parents
Loading