a687d327 - fix: Propagate `strip_tensor_paddings` (#7426)

Trying to use `DeepSpeed/deepspeed/checkpoint/ds_to_universal.py`, I encountered:

```python
Traceback (most recent call last):
  File "/opt/aurora/24.347.0/frameworks/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 114, in extract_zero_shards
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/deepspeed_checkpoint.py", line 124, in get_zero_checkpoint_state
    return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/zero_checkpoint.py", line 62, in get_state_for_rank
    self._strip_tensor_paddings(sd)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/zero_checkpoint.py", line 110, in _strip_tensor_paddings
    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
RuntimeError: narrow(): length must be non-negative.
```

(see full traceback[^traceback] below)

The issue is that there is no way to propagate the `strip_tensor_paddings` argument from the [`DeepSpeedCheckpoint.get_zero_checkpoint_state(...)`](https://github.com/deepspeedai/DeepSpeed/blob/affee605e47c9befd21c4c1445e11fd29d295201/deepspeed/checkpoint/deepspeed_checkpoint.py#L123) method through to the [`ZeroCheckpoint.get_state_for_rank(...)`](https://github.com/deepspeedai/DeepSpeed/blob/affee605e47c9befd21c4c1445e11fd29d295201/deepspeed/checkpoint/zero_checkpoint.py#L53) method (which accepts it as an argument), because `get_zero_checkpoint_state` itself does not accept it.
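For context on the error itself, the failing call is the `torch.narrow(...)` inside `_strip_tensor_paddings`, which rejects any negative length. The snippet below is a minimal, self-contained sketch of that failure mode; the tensor and the negative `raw_length` value are made up for illustration and are not taken from an actual checkpoint:

```python
import torch

# Sketch of the failure mode seen in the traceback above: torch.narrow()
# refuses a negative length, so any code path that ends up computing a
# negative raw_length raises this RuntimeError. Values are illustrative only.
state_value = torch.zeros(8)
raw_length = -1  # hypothetical negative length, mirroring the failing call

try:
    torch.narrow(state_value, 0, 0, raw_length).clone()
except RuntimeError as err:
    print(err)  # narrow(): length must be non-negative.

# With a valid (non-negative) length, the same call simply trims the padding:
stripped = torch.narrow(state_value, 0, 0, 4).clone()
print(stripped.shape)  # torch.Size([4])
```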
This PR adds an additional `strip_tensor_paddings` argument (default `True`) to the `DeepSpeedCheckpoint.get_zero_checkpoint_state` method and passes it through to `self.zero_checkpoint.get_state_for_rank(..., strip_tensor_paddings=strip_tensor_paddings)`, as shown below:

```diff
-    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index) -> dict:
+    def get_zero_checkpoint_state(self, pp_index, tp_index, dp_index, strip_tensor_paddings: bool = True) -> dict:
         return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
                                                        tp_index=tp_index,
                                                        dp_index=dp_index,
-                                                       keys_to_ignore=[PARAM_SHAPES])
+                                                       keys_to_ignore=[PARAM_SHAPES],
+                                                       strip_tensor_paddings=strip_tensor_paddings)
```
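As a usage sketch (not part of the diff above), a caller can then opt out of padding stripping when it would trigger the `narrow()` failure. The checkpoint directory and the zero index values below are placeholders, not taken from this PR:

```python
from deepspeed.checkpoint import DeepSpeedCheckpoint

# Hypothetical caller: with this PR, strip_tensor_paddings can be forwarded
# explicitly instead of always being applied.
ds_checkpoint = DeepSpeedCheckpoint("checkpoints/global_step158945")  # placeholder path
sd = ds_checkpoint.get_zero_checkpoint_state(
    pp_index=0,
    tp_index=0,
    dp_index=0,
    strip_tensor_paddings=False,  # skip _strip_tensor_paddings, avoiding the narrow() error
)
print(sorted(sd.keys()))
```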
[^traceback]: Full traceback:

<details closed><summary>[Full Traceback]:</summary>

```bash
#[🐍 aurora_nre_models_frameworks-2025.0.0](👻 aurora_nre_models_frameworks-2025.0.0)
#[/f/A/C/f/p/a/Megatron-DeepSpeed][🌱 saforem2/fix-formatting][✓]
#[07/12/25 @ 16:07:12][x4209c2s4b0n0]
; ckpt_dir=checkpoints/ws768_ds_stage1_nl32_hs4096_mb1_seq4096_gb3072_sp1_pp1_tp1_bf16_optadamw_lr_lwf_flash ; gs=$(cat "${ckpt_dir}/latest_checkpointed_iteration.txt") && echo "global step: ${gs}" && python3 deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py --input_folder "${ckpt_dir}/global_step${gs}" --output_folder "${ckpt_dir}/global_step${gs}_universal" --keep_temp_folder
global step: 158945
[W712 16:07:17.966425018 OperatorEntry.cpp:155] Warning: Warning only once for all operators, other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
    registered at /build/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /build/pytorch/build/aten/src/ATen/RegisterCPU.cpp:30476
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:2971 (function operator())
/opt/aurora/24.347.0/frameworks/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py:6: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  import pkg_resources
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
[2025-07-12 16:07:27,740] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to xpu (auto detect)
[2025-07-12 16:07:29,078] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
args = Namespace(input_folder='checkpoints/ws768_ds_stage1_nl32_hs4096_mb1_seq4096_gb3072_sp1_pp1_tp1_bf16_optadamw_lr_lwf_flash/global_step158945', output_folder='checkpoints/ws768_ds_stage1_nl32_hs4096_mb1_seq4096_gb3072_sp1_pp1_tp1_bf16_optadamw_lr_lwf_flash/global_step158945_universal', num_extract_workers=4, num_merge_workers=2, keep_temp_folder=True, strict=True, inject_missing_state=False)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in checkpoints/ws768_ds_stage1_nl32_hs4096_mb1_seq4096_gb3072_sp1_pp1_tp1_bf16_optadamw_lr_lwf_flash/global_step158945 to Universal checkpoint in checkpoints/ws768_ds_stage1_nl32_hs4096_mb1_seq4096_gb3072_sp1_pp1_tp1_bf16_optadamw_lr_lwf_flash/global_step158945_universal
/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/megatron/core/tensor_parallel/layers.py:290: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(
/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/megatron/core/tensor_parallel/layers.py:334: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):
[2025-07-12 16:07:39,134079][I][ezpz/__init__:264:ezpz] Setting logging level to 'INFO' on 'RANK == 0'
[2025-07-12 16:07:39,136376][I][ezpz/__init__:265:ezpz] Setting logging level to 'CRITICAL' on all others 'RANK != 0'
*** 1. Extracting ZeRO fragments
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋| 767/768 [01:29<00:00, 8.53it/s]
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/opt/aurora/24.347.0/frameworks/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 114, in extract_zero_shards
    sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=pp_index, tp_index=tp_index, dp_index=dp_index)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/deepspeed_checkpoint.py", line 124, in get_zero_checkpoint_state
    return self.zero_checkpoint.get_state_for_rank(pp_index=pp_index,
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/zero_checkpoint.py", line 62, in get_state_for_rank
    self._strip_tensor_paddings(sd)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/venvs/aurora/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/site-packages/deepspeed/checkpoint/zero_checkpoint.py", line 110, in _strip_tensor_paddings
    group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
RuntimeError: narrow(): length must be non-negative.
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 549, in <module>
    main(args)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 499, in main
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 370, in _extract_zero_shard_files
    _do_parallel_work(do_work, _3d_range_list, args.num_extract_workers)
  File "/lus/flare/projects/AuroraGPT/CPT-AuroraGPT-v0/foremans/projects/argonne-lcf/Megatron-DeepSpeed/deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py", line 354, in _do_parallel_work
    results.append(f.result())
  File "/opt/aurora/24.347.0/frameworks/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/opt/aurora/24.347.0/frameworks/aurora_nre_models_frameworks-2025.0.0/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: narrow(): length must be non-negative.
[1]    144664 exit 1     python3 deps/DeepSpeed/deepspeed/checkpoint/ds_to_universal.py --input_folder
took: 0h:02m:08s
```

</details>

Signed-off-by: Sam Foreman <saforem2@gmail.com>