flax
01460b1c - Add metadata helper transform for adding axis metadata.

Commit
3 years ago
Add metadata helper transform for adding axis metadata. Some users have custom combinatory/transform layers that effectively add an axis to their layer parameters, but that aren't simple vmaps or scans. This adds a metadata helper transform for adding an axis annotation "on the way out" for inits and mutations, and removing that axis "on the way in" during application, exactly the same as the behavior used in scan/vmap - but only touching the metadata of boxed variables, and not modifying any traced variables.
Author
Committer
Parents
Loading