jax
5ff55080 - [mutable-arrays] add discharge_state2 for ClosedJaxprs, with caching

Commit
127 days ago
[mutable-arrays] add discharge_state2 for ClosedJaxprs, with caching This helper function follows our usual pattern of caching on ClosedJaxprs, which handles jaxprs and const lists together by id, and also keeps weak references to both. We can probably update most or all callers to use discharge_state2 instead of discharge_state, at which point we can delete the latter. This PR only updates the jit lowering rule, jit discharge rule, and scan discharge rule to use the new helper.
Author
Parents
Loading