DeepSpeed
91d14527 - Fix the universal checkpoint issue for stage3 when there are multiple subgroups. (#7585)

Fix the universal checkpoint issue for stage3 when there are multiple subgroups. (#7585)

**Describe the bug**

When the model is large and there are multiple subgroups, running ds_to_universal.py fails with the error log below:

```
*** 1. Extracting ZeRO fragments
  0%|          | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
    main()
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
    ds_to_universal_main(args)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
    results.append(do_work(work))
                   ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
    dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
    state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).
```

**To Reproduce**

Steps to reproduce the behavior:

1. Train with a large model, or set sub_group_size to a lower value, then train and save a checkpoint.
2. Run ds_to_universal.py (a driver sketch is shown after this message).

**The reason**

The previous stage3 universal checkpoint implementation did not take subgroups into account. The following problems were also found during debugging:

* It cannot handle multiple subgroups, which results in data loss.
* When load_checkpoint is True, all processes save to the same zero model checkpoint file. If multiple processes write at the same time, the file gets corrupted; occasional file corruption was observed during testing.

Related issue: #7584

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
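
A minimal driver for the reproduction steps above, in the spirit of the ds_to_universal_example.py that appears in the traceback. The checkpoint paths are placeholders, and the flag names (--input_folder, --output_folder, --num_extract_workers) are assumptions based on the ds_to_universal.py CLI and may differ across DeepSpeed versions:

```python
# Sketch: convert a saved ZeRO stage3 checkpoint to a universal checkpoint
# by invoking DeepSpeed's ds_to_universal script. Paths and flag names are
# assumptions for illustration, not taken verbatim from this commit.
import subprocess
import sys

subprocess.run(
    [
        sys.executable,
        "-m",
        "deepspeed.checkpoint.ds_to_universal",
        "--input_folder", "checkpoints/global_step100",            # hypothetical stage3 checkpoint dir
        "--output_folder", "checkpoints/global_step100_universal",  # hypothetical output dir
        "--num_extract_workers", "1",
    ],
    check=True,
)
```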
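
To illustrate "The reason" above: with multiple subgroups, ZeRO stage3 keeps one flat optimizer-state tensor per subgroup, so a parameter fragment must be narrowed inside its own subgroup's buffer. An offset accumulated as if there were a single flat buffer can exceed the length of an individual subgroup tensor, which is exactly the `start + length exceeds dimension size` failure in the log (155582464 vs. 74499072). The sketch below is purely illustrative, not the code from this commit; the helper name and layout assumption are hypothetical:

```python
# Illustrative sketch only: remap an offset computed across all subgroups
# into (subgroup index, fragment) by walking the per-subgroup flat tensors,
# assuming the subgroup buffers are laid out back-to-back.
from typing import List, Tuple

import torch


def locate_fragment(flat_subgroup_tensors: List[torch.Tensor],
                    global_offset: int, numel: int) -> Tuple[int, torch.Tensor]:
    for sub_idx, flat in enumerate(flat_subgroup_tensors):
        if global_offset + numel <= flat.numel():
            # Narrow within this subgroup's own buffer. Narrowing a single
            # subgroup tensor with the accumulated global offset instead is
            # what raises the RuntimeError shown in the log above.
            return sub_idx, flat.narrow(0, global_offset, numel).clone()
        global_offset -= flat.numel()
    raise ValueError("offset does not fall inside any subgroup buffer")
```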