Fix the universal checkpoint issue for stage3 when there are multiple subgroups. (#7585)
**Describe the bug**
When the model is large and there are multiple subgroups, running
ds_to_universal.py fails with the error log below:
```
*** 1. Extracting ZeRO fragments
0%| | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
main()
File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
ds_to_universal_main(args)
File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
_extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
_do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
results.append(do_work(work))
^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).
```
**To Reproduce**
Steps to reproduce the behavior:
1. Train with a large model, or set sub_group_size to a lower value so that
multiple subgroups are created, then save a checkpoint (see the config sketch below)
2. Run ds_to_universal.py
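
For reference, here is a minimal ZeRO-3 config sketch that forces multiple subgroups by lowering sub_group_size; the values are illustrative placeholders, not the exact setup used to hit the bug:

```python
# Hypothetical ZeRO-3 config sketch: a sub_group_size well below the default
# (1e9) splits a moderately sized model's parameters into several subgroups.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "sub_group_size": 1e7,  # low value so multiple subgroups are created
    },
    "bf16": {"enabled": True},
}

# engine, _, _, _ = deepspeed.initialize(model=model,
#                                         model_parameters=model.parameters(),
#                                         config=ds_config)
# ...train, call engine.save_checkpoint(save_dir), then run ds_to_universal.py on save_dir
```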
**The reason**
I found that the previous stage 3 universal checkpoint implementation did
not take subgroups into account. I also found the following problems while
debugging (see the sketch after this list).
* It cannot handle multiple subgroups, which results in data loss.
* When load_checkpoint is True, all processes save to the same ZeRO model
checkpoint file. If multiple processes write to it at the same time, the
file becomes corrupted; this corruption was occasionally observed during
testing.
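
As a rough illustration of the first point (a toy sketch, not the actual DeepSpeed code): with multiple subgroups the optimizer state lives in a separate flat tensor per subgroup, so an offset accumulated across the whole parameter list can run past the end of a subgroup-local tensor, which is exactly the narrow() failure in the traceback. Indexing with a subgroup-local offset avoids it:

```python
import torch

# Toy setup: two subgroups, each with its own flat optimizer-state tensor.
subgroup_flat_states = [torch.zeros(100), torch.zeros(80)]
# (param_name, owning_subgroup, offset_within_subgroup, numel) -- made-up values
params = [("a", 0, 0, 60), ("b", 0, 60, 40), ("c", 1, 0, 80)]

# Buggy pattern: treat everything as one flat buffer with a single global offset.
global_offset = 0
try:
    for name, _, _, numel in params:
        subgroup_flat_states[0].narrow(0, global_offset, numel)
        global_offset += numel
except RuntimeError as e:
    print("fails once the offset crosses a subgroup boundary:", e)

# Fixed pattern: look up the owning subgroup and use its local offset.
for name, sub_idx, offset, numel in params:
    fragment = subgroup_flat_states[sub_idx].narrow(0, offset, numel).clone()
    print(name, fragment.numel())
```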
Related issue: #7584
---------
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>