flax
c7fa10a0 - Migrate Flax dataclass to the newest JAX pytree keypath API.

Commit
2 years ago
Migrate Flax dataclass to the newest JAX pytree keypath API. * `flax.struct.dataclass` already registered via `register_keypath`, and this one only changes it to latest API. * `flax.core.FrozenDict` was registered so that flattening a frozen dict should be the same as flattening the underlying dict. This makes its serialization backward-compatible. PiperOrigin-RevId: 515096454
Author
Committer
Parents
Loading