Fix `ds_secondary_tensor` being dirty when loading the model or a ZeRO checkpoint for ZeRO++ (#7707)
`ds_secondary_tensor` may be dirty when loading the model or a ZeRO
checkpoint under ZeRO++.
* 1. Loading the model
My task is SFT with transformers. In the transformers code, the model is
initialized with code like the following:
```python
with deepspeed.zero.Init():
    model = xxx
```
After this, each `param` is already a DeepSpeed-partitioned tensor, meaning
both `ds_tensor` and `ds_secondary_tensor` exist. Then `load_model` is called
to reload the model weights:
```python
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
    if torch.distributed.get_rank() == 0:
        module._load_from_state_dict(*args)
```
In `GatheredParameters.__exit__`, `params[0].partition` is called with
`has_been_updated` set to `True`, indicating that the partitioned data needs
to be refreshed. However, `has_been_updated` was not passed on to
`_partition_param_sec`, which leaves `ds_secondary_tensor` dirty (a
self-contained sketch of the flag forwarding is shown below).
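The sketch below is a self-contained toy, not the actual DeepSpeed code: the
class and method bodies (`ToyParam`, `_partition_param`,
`_partition_param_sec`) only mirror the roles of the real methods in
`partition_parameters.py`. It illustrates the point above: as long as
`has_been_updated` is forwarded to the secondary partitioning step,
`ds_secondary_tensor` stays in sync with `ds_tensor`; dropping the flag is
exactly the bug shape described here.
```python
import torch


class ToyParam:
    """Toy stand-in for a ZeRO-3/ZeRO++ partitioned parameter (illustrative only)."""

    def __init__(self, data: torch.Tensor):
        self.data = data
        self.ds_tensor = data.clone()            # primary partition
        self.ds_secondary_tensor = data.clone()  # ZeRO++ secondary copy

    def partition(self, has_been_updated: bool = False) -> None:
        self._partition_param(has_been_updated=has_been_updated)
        # The fix boils down to forwarding the flag here as well; without it,
        # the secondary copy silently keeps its old (dirty) contents.
        self._partition_param_sec(has_been_updated=has_been_updated)

    def _partition_param(self, has_been_updated: bool) -> None:
        if has_been_updated:
            self.ds_tensor.copy_(self.data)

    def _partition_param_sec(self, has_been_updated: bool) -> None:
        if has_been_updated:
            self.ds_secondary_tensor.copy_(self.data)


p = ToyParam(torch.zeros(4))
p.data.fill_(1.0)                   # rank 0 writes new weights from the state dict
p.partition(has_been_updated=True)  # corresponds to GatheredParameters.__exit__
assert torch.equal(p.ds_secondary_tensor, p.ds_tensor)  # in sync only if the flag is forwarded
```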
* 2. Loading the ZeRO checkpoint
The ZeRO checkpoint is loaded into `fp16_partitioned_groups_flat`, which
means `param.ds_tensor` is updated. However, the data in
`param.ds_secondary_tensor` is not updated, and the next allgather still
uses the dirty `param.ds_secondary_tensor`.
A dirty `ds_secondary_tensor` can lead to abnormal loss. After
`invalidate_secondary_tensor` is called in `_post_step`, the loss returns to
normal, which is why the loss anomaly only appears during the first few
steps. A sketch of this invalidation idea follows below.
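A minimal, self-contained sketch of the checkpoint-path idea, using assumed
toy names (`ToyZeroParam`, `load_zero_checkpoint`, `secondary_for_allgather`)
rather than the real DeepSpeed API: once `ds_tensor` is overwritten in place
from the flat checkpoint, the stale `ds_secondary_tensor` has to be
invalidated so the next allgather rebuilds it from the fresh primary
partition.
```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyZeroParam:
    ds_tensor: torch.Tensor                      # primary partition
    ds_secondary_tensor: Optional[torch.Tensor]  # ZeRO++ secondary copy


def load_zero_checkpoint(param: ToyZeroParam, shard: torch.Tensor) -> None:
    # Checkpoint loading restores only the primary partition ...
    param.ds_tensor.copy_(shard)
    # ... so the secondary copy must be invalidated instead of gathered as-is.
    param.ds_secondary_tensor = None


def secondary_for_allgather(param: ToyZeroParam) -> torch.Tensor:
    # A missing secondary copy is rebuilt from the up-to-date primary partition
    # before the next (hierarchical) allgather uses it.
    if param.ds_secondary_tensor is None:
        param.ds_secondary_tensor = param.ds_tensor.clone()
    return param.ds_secondary_tensor


p = ToyZeroParam(ds_tensor=torch.zeros(8), ds_secondary_tensor=torch.zeros(8))
load_zero_checkpoint(p, torch.ones(8))
assert torch.equal(secondary_for_allgather(p), p.ds_tensor)
```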
Related issue: #7606
Signed-off-by: zhengchenyu <zhengchenyu16@163.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>