DeepSpeed
c56a4b9e - Improve universal checkpoint (#5289)

Commit
1 year ago
Improve universal checkpoint (#5289)

This PR includes the following improvements regarding universal checkpoints.

- Restoring step

  A universal checkpoint saves the training step count taken from the engine. In https://github.com/microsoft/DeepSpeed/pull/5263, we changed the loader to always restore this count into the optimizer's per-param state (`optimizer_state['state'][param]['step']`) and into each param group. However, this approach does not restore the optimizer's state and param groups precisely, because optimizers behave differently:
  - Torch's Adam does not keep `step` in a param group and only uses `optimizer_state['state'][param]['step']`.
  - Apex's fused Adam only uses `step` in a param group.
  - DeepSpeed's fused Adam creates `step` in a param group but never updates it; it only uses `optimizer_state['state'][param]['step']`.

  Consequently, this led to discrepancies between the restored and original states of the optimizer and param groups. This PR modifies the restoration process to ensure that the step number in the optimizer's state and param groups matches the original setup, effectively aligning the restored and original optimizer states and param groups.

- Unit tests of DP size scaling

  This PR also adds unit tests to verify universal checkpointing. They run training with DP, save a checkpoint, and convert it into a universal checkpoint. They then load the checkpoint with a different DP size and validate that the parameters and the all-gathered (ZeRO 1/2) optimizer states match.

- Fix bug of loading with `load_optimizer_states=False`

  The loader did not load parameters from a universal checkpoint when `load_optimizer_states=False`. https://github.com/microsoft/DeepSpeed/pull/5289/commits/c8c0498fe589c20ad830efdecf6a0a28f38fb7ae fixes this issue.
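The step-restoration idea above can be sketched as follows. This is a minimal illustration, not DeepSpeed's actual code: the helper `restore_step` and the plain-dict structures are hypothetical stand-ins for the optimizer state and param groups. The key point is writing the saved step into both locations, so every optimizer variant reads the correct value regardless of which one it consults.

```python
# Hedged sketch (hypothetical helper, not DeepSpeed's implementation):
# restore the saved training step count into BOTH the per-param state and
# each param group, covering all three optimizer behaviors described above.

def restore_step(per_param_state, param_groups, step):
    """Write `step` everywhere an optimizer variant might read it.

    per_param_state: dict mapping param id -> per-param state dict
    param_groups: list of param-group dicts
    step: training step count saved in the universal checkpoint
    """
    # Torch's Adam and DeepSpeed's fused Adam read the per-param 'step'.
    for state in per_param_state.values():
        state['step'] = step
    # Apex's fused Adam reads 'step' from the param group (and DeepSpeed's
    # fused Adam creates it there, so keep it consistent too).
    for group in param_groups:
        group['step'] = step

# Toy example mimicking the structures in the commit message.
per_param_state = {0: {'exp_avg': [0.0], 'step': 3},
                   1: {'exp_avg': [0.0], 'step': 3}}
param_groups = [{'lr': 1e-3, 'step': 0}]
restore_step(per_param_state, param_groups, 100)
```

After the call, both `per_param_state[i]['step']` and `param_groups[0]['step']` hold 100, so the restored optimizer matches the original regardless of which location a given Adam variant uses.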
  • deepspeed
    • File
      __init__.py
    • checkpoint
      • File
        constants.py
      • File
        ds_to_universal.py
      • File
        universal_checkpoint.py
      • File
        zero_checkpoint.py
    • runtime
      • File
        __init__.py
      • File
        base_optimizer.py
      • File
        bf16_optimizer.py
      • File
        engine.py
      • fp16
        • File
          fused_optimizer.py
        • File
          unfused_optimizer.py
      • zero
        • File
          stage3.py
        • File
          stage_1_and_2.py
  • tests/unit/checkpoint
    • File
      common.py
    • File
      test_universal_checkpoint.py
    • File
      test_zero_optimizer.py