flax
2aaff695 - Fix a bug in axes_scan for custom prngs.

Commit
4 years ago
Fix a bug in axes_scan for custom prngs. axes_scan currently tree_maps pe.PartialVal.unknown across carry and scan args - this in general shouldn't be done as custom array-like entities (such as the KeyArray) can't handle being recreated from partial values during unflattening. Instead we should tree_map avals and only convert them to partial values in the flattened state immediately before tracing. PiperOrigin-RevId: 401296069
Author
Committer
Parents
Loading