DeepSpeed
c069ceb3 - Fix that ds_secondary_tensor may be dirty when loading the model or zero checkpoint for zero++. (#7707)

Commit
49 days ago
Fix that ds_secondary_tensor may be dirty when loading the model or zero checkpoint for zero++. (#7707) `ds_secondary_tensor` may be dirty during model loading or zero checkpointing for zero++. * 1 Loading model My task is transformers SFT. In the transformers code, initialization is done using code like the following: ``` with deepspeed.zero.Init(): model = xxx ``` After this, `param` is already a ds tensor, meaning both `ds_tensor` and `ds_secondary_tensor` exist. Then `load_model` is called to reload the model. ``` with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): if torch.distributed.get_rank() == 0: module._load_from_state_dict(*args) ``` In `GatheredParameters.__exit__`, `params[0].partition` is called, and `has_been_updated` is set to `True`, indicating that data updates are needed. However, `_partition_param_sec` did not pass `has_been_updated`. This results in `ds_secondary_tensor` being dirty. * 2 Loading zero checkpoint The zero checkpoint is loaded into `fp16_partitioned_groups_flat`, meaning `param.ds_tensor` has been updated. However, the data in `param.ds_secondary_tensor` has not been updated. But the next `allgather` will use the dirty `param.ds_secondary_tensor`. A dirty `ds_secondary_tensor` can lead to abnormal loss. After calling `invalidate_secondary_tensor` in `_post_step`, the loss returns to normal. This is why loss anomaly only occurs during beginning steps. Relate issue: #7606 Signed-off-by: zhengchenyu <zhengchenyu16@163.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author
Parents
Loading