flax
719217b6 - Let Flax-Orbax to not port the shape of `target` arrays when they port the `target` shardings.

Commit
2 years ago
Let Flax-Orbax to not port the shape of `target` arrays when they port the `target` shardings. This allow people to continue using Flax checkpointing API with target pytrees of desired sharding but smaller shapes, avoiding memory burdens. No impact if user is using native Orbax. PiperOrigin-RevId: 551289243
Author
Committer
Parents
Loading