Add `lower` to `specialize` making it a true `Stage`.
So now users can do:
```
specialized = jax.jit(f).specialize(*args)
print(specialized.jaxpr, specialized.out_info)
lowered = specialized.lower()
compiled = lowered.compile()
```
PiperOrigin-RevId: 640737396