jax
60a8d25b - Support TransformedRefs in higher-order primitives by tree flattening them.

Commit
7 days ago
Support TransformedRefs in higher-order primitives by tree flattening them. This CL changes the tree flattening internally (used by pjit and higher-order primitive) to use the new tracing_registry - like `default_registry` (user-facing), but with additional flattening rules for important internal types we want to "lower" with the pytree mechanism, TransformedRefs being the first example (maybe the only example?). Higher-order primitives should now use either `tracing_registry` directly or `FlatTree` for argument flattening which defaults to the tracing_registry instead of default_registry. This CL does NOT add support for vmap, but does include fixes to jax.export, so that export can use the tracing registry for input/output pytree, to match the behavior of jit. Note, for consistency in jit and export, out tree will also be flattened with tracing_registry, but we do not currently support or want to support TransformedRefs as outputs. PiperOrigin-RevId: 895486134
Author
Parents
Loading