Fix a bug in axes_scan for custom prngs.
axes_scan currently tree_maps pe.PartialVal.unknown across carry
and scan args - this in general shouldn't be done as custom array-like
entities (such as the KeyArray) can't handle being recreated from partial
values during unflattening. Instead we should tree_map avals and only
convert them to partial values in the flattened state immediately before
tracing.
PiperOrigin-RevId: 401296069