Add a new stage_p primitive and Trace.stage_value operation.
In a benchmark I found that a significant amount of time was being spent in functions such as `jnp.asarray(x)`, where `x` is either a `Tracer` or a small Python integer.
Before this change, `asarray` ultimately called `convert_element_type_p.bind()` to lift a value into a trace. To avoid littering the jaxpr with many `convert_element_type` equations, `convert_element_type` has a `convert_elt_type_folding_rule` that allows `DynamicJaxprTracer` to elide them again. However, this is awfully indirect! All we wanted to do was lift a value into a `Trace`, and in essence we were paying most of the overhead of constructing and almost immediately destroying a jaxpr equation for each constant.
`stage_p` is a primitive with a custom `bind_with_trace` rule that, rather than going through the `process_primitive` path, calls a new `Trace.stage_value` method. Semantically this should do exactly what the previous path did, only more efficiently. For example, in the case of `DynamicJaxprTracer` we can simply call `to_jaxpr_tracer` and skip the rest of the `process_primitive` machinery.
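
To illustrate the difference between the two paths, here is a toy model in plain Python. It is a sketch only and does not use or mirror JAX's real classes: `ToyTrace`, `ToyTracer`, `ToyPrimitive`, `ToyStagePrimitive`, and the `work` counter are all hypothetical names invented for this example, and the folding step is simplified to "always drop the equation".

```python
# Toy model only: these classes are hypothetical and are NOT JAX's implementation.
from dataclasses import dataclass, field


@dataclass
class ToyTracer:
    value: object


@dataclass
class ToyTrace:
    equations: list = field(default_factory=list)
    work: int = 0  # counts equation constructions as a stand-in for overhead

    def to_tracer(self, x):
        # Lift a raw value into the trace; tracers pass through unchanged.
        return x if isinstance(x, ToyTracer) else ToyTracer(x)

    def process_primitive(self, name, x):
        # Old path: build an equation for the primitive...
        self.work += 1
        tracer = self.to_tracer(x)
        self.equations.append((name, tracer))
        # ...then a folding rule (simplified here) immediately elides it.
        if name == "convert_element_type":
            self.equations.pop()
        return tracer

    def stage_value(self, x):
        # New path: lift the value directly, with no equation bookkeeping.
        return self.to_tracer(x)


class ToyPrimitive:
    def __init__(self, name):
        self.name = name

    def bind_with_trace(self, trace, x):
        # Default rule: route through process_primitive.
        return trace.process_primitive(self.name, x)


class ToyStagePrimitive(ToyPrimitive):
    def bind_with_trace(self, trace, x):
        # Custom rule: bypass process_primitive entirely.
        return trace.stage_value(x)


old_trace, new_trace = ToyTrace(), ToyTrace()
old_result = ToyPrimitive("convert_element_type").bind_with_trace(old_trace, 3)
new_result = ToyStagePrimitive("stage").bind_with_trace(new_trace, 3)
```

In this toy, both paths produce the same tracer and leave no equations behind, but only the old path pays for building and discarding an equation.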
PiperOrigin-RevId: 902599067