Improve universal checkpoint (#5289)
This PR includes the following improvements regarding universal
checkpointing.
- Restoring step
A universal checkpoint saves the training step count taken from the
engine. In
https://github.com/microsoft/DeepSpeed/pull/5263, we changed the loader to
always restore this count into the optimizer's per-parameter state
(`optimizer_state['state'][param]['step']`) and into each param group.
However, this approach does not restore the optimizer's state and param
groups precisely, because optimizers handle the step count differently.
Torch's Adam does not create `step` in its param groups and only uses
`optimizer_state['state'][param]['step']`. Apex's FusedAdam only uses
`step` in its param groups. DeepSpeed's FusedAdam creates `step` in its
param groups but never updates it; it only uses
`optimizer_state['state'][param]['step']`.
Consequently, this leads to discrepancies between the restored and
original optimizer states and param groups.
This PR modifies the restoration process to ensure that the step number
in the optimizer's state and param groups matches the original setup,
effectively aligning the restored and original optimizer states
and param groups (see the sketch after this paragraph).
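As a concrete illustration of the first behavior, the snippet below (plain PyTorch, independent of this PR) shows that torch's Adam keeps its step count only in the per-parameter state and not in its param groups; Apex's and DeepSpeed's FusedAdam differ as described above.

```python
# Minimal sketch (plain PyTorch, not DeepSpeed): torch.optim.Adam keeps `step`
# only in the per-parameter state, never in its param groups.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model(torch.randn(8, 4)).sum().backward()
opt.step()

for group in opt.param_groups:
    print("'step' in param_group:", "step" in group)  # False for torch's Adam

for p in model.parameters():
    print("per-param step:", opt.state[p]["step"])  # 1 after one optimizer step
```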
- Unit tests of DP size scaling
This PR also adds unit tests to verify universal checkpointing. They run
training with data parallelism (DP), save a checkpoint, and convert it to a
universal checkpoint. They then load the checkpoint with a different DP size
and validate that the parameters and the all-gathered (ZeRO 1/2) optimizer
states match.
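The sketch below outlines that flow; it is illustrative only and not the test code itself. The checkpoint directory, tag, and the `load_universal` config key are assumptions, and the actual tests live in the repo's unit test suite.

```python
# Illustrative sketch of the DP-rescaling test flow (not the added test code).
# CKPT_DIR, TAG, and the `load_universal` config key are assumptions.
import subprocess
import torch

CKPT_DIR = "checkpoints"   # hypothetical checkpoint root
TAG = "global_step10"      # hypothetical tag written by the first run

def convert_to_universal(ckpt_dir: str, tag: str) -> None:
    # Convert the ZeRO checkpoint saved under <ckpt_dir>/<tag> into a
    # universal checkpoint using DeepSpeed's offline conversion script.
    subprocess.check_call([
        "python", "deepspeed/checkpoint/ds_to_universal.py",
        "--input_folder", f"{ckpt_dir}/{tag}",
        "--output_folder", f"{ckpt_dir}/{tag}_universal",
    ])

def assert_states_match(expected: dict, actual: dict) -> None:
    # Compare tensors captured before saving with those seen after reloading
    # under a different DP size (parameters and all-gathered optimizer states).
    assert expected.keys() == actual.keys()
    for name in expected:
        assert torch.allclose(expected[name], actual[name]), f"mismatch: {name}"

# Flow exercised by the tests:
# 1) train with DP=N and call engine.save_checkpoint(CKPT_DIR, tag=TAG)
# 2) convert_to_universal(CKPT_DIR, TAG)
# 3) relaunch with a different DP size, set
#    ds_config["checkpoint"] = {"load_universal": True}  # assumed config key
#    and call engine.load_checkpoint(CKPT_DIR, tag=TAG)
# 4) assert_states_match(...) on parameters and ZeRO 1/2 optimizer states
```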
- Fix a bug in loading with `load_optimizer_states=False`
The loader did not load parameters from a universal checkpoint when
`load_optimizer_states=False` was passed (an example call is sketched below).
https://github.com/microsoft/DeepSpeed/pull/5289/commits/c8c0498fe589c20ad830efdecf6a0a28f38fb7ae
fixes this issue.
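The affected call path looks roughly like the sketch below, assuming `engine` is an already initialized DeepSpeed engine; the helper name and the checkpoint root/tag passed by the caller are hypothetical.

```python
# Minimal sketch: load module weights from a (universal) checkpoint while
# skipping optimizer and LR-scheduler state. The helper name is hypothetical.
def load_weights_only(engine, ckpt_root: str, tag: str):
    load_path, client_state = engine.load_checkpoint(
        ckpt_root,
        tag=tag,
        load_optimizer_states=False,   # before the fix, parameters were also skipped here
        load_lr_scheduler_states=False,
    )
    return load_path, client_state
```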