jax
55d0f5ef - Add `lower` to `specialize` making it a true `Stage`.

Commit
1 year ago
Add `lower` to `specialize` making it a true `Stage`. So now users can do: ``` specialized = jax.jit(f).specialize(*args) print(specialized.jaxpr, specialized.out_info) lowered = specialized.lower() compiled = lowered.compile() ``` PiperOrigin-RevId: 640737396
Author
Committer
Parents
Loading