Support TransformedRefs in higher-order primitives by tree flattening them.
This CL changes the tree flattening internally (used by pjit and higher-order
primitive) to use the new tracing_registry - like `default_registry`
(user-facing), but with additional flattening rules for important internal
types we want to "lower" with the pytree mechanism, TransformedRefs being the
first example (maybe the only example?).
Higher-order primitives should now use either `tracing_registry` directly or `FlatTree` for argument flattening which
defaults to the tracing_registry instead of default_registry.
This CL does NOT add support for vmap, but does include fixes to jax.export, so
that export can use the tracing registry for input/output pytree, to match the
behavior of jit.
Note, for consistency in jit and export, out tree will also be flattened with
tracing_registry, but we do not currently support or want to support
TransformedRefs as outputs.
PiperOrigin-RevId: 895486134