Allow restoring raw state_dict using checkpoints.restore_checkpoint().
Currently, in order to restore a checkpoint, one needs to pass a pytree
target with the same structure than the one contained in the checkpoint.
This is problematic if one wants to restore only part of the checkpoint,
for instance, to initialize a new but slightly different model (e.g.
suppose that I just change the number of outputs in the last layer, and
want to re-use everything else). By restoring the raw state_dict, one
has more liberty to manipulate it in arbitrary ways, without needing to
instantiate a pytree with the same structure as the checkpoint.
If the raw state_dict is desired, just pass target=None.