flax
c1d00aea - Allow restoring raw state_dict using checkpoints.restore_checkpoint().

Commit
5 years ago
Allow restoring raw state_dict using checkpoints.restore_checkpoint(). Currently, in order to restore a checkpoint, one needs to pass a pytree target with the same structure than the one contained in the checkpoint. This is problematic if one wants to restore only part of the checkpoint, for instance, to initialize a new but slightly different model (e.g. suppose that I just change the number of outputs in the last layer, and want to re-use everything else). By restoring the raw state_dict, one has more liberty to manipulate it in arbitrary ways, without needing to instantiate a pytree with the same structure as the checkpoint. If the raw state_dict is desired, just pass target=None.
Author
Parents
Loading