Bug fix for the "Link bit16 and fp32 parameters in partition" (#5681)
In the function `_link_all_hp_params`
[link](https://github.com/microsoft/DeepSpeed/blob/b33873d234cf6679a3046be9a137682c3469d1fb/deepspeed/runtime/zero/stage_1_and_2.py#L575):
```python
def _link_all_hp_params(self):
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    if self.cpu_offload:
        self._get_offload_gradient_dict()

    for i, _ in enumerate(self.optimizer.param_groups):
        # Link bit16 and fp32 params in partition
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        partition_size = self.bit16_groups_flat[i].numel() // dp_world_size
        flat_hp_partition = self.single_partition_of_fp32_groups[i]
        link_hp_params(lp_param_list=self.bit16_groups[i],
                       flat_hp_partition=flat_hp_partition,
                       gradient_dict=self.averaged_gradients,
                       offload_gradient_dict=self.offload_gradient_dict,
                       use_offload=self.cpu_offload,
                       param_group_index=i,
                       partition_start=partition_id * partition_size,
                       partition_size=partition_size,
                       dp_group=self.real_dp_process_group[i])
```
The line `dp_world_size = dist.get_world_size(group=self.dp_process_group)` always sets `dp_world_size` to the global data parallel world size.
However, for MoE parameter groups, `partition_size = self.bit16_groups_flat[i].numel() // dp_world_size` then yields an incorrect `partition_size` when `ep_size > 1` (i.e., when expert parallelism is enabled), because those groups are partitioned over a smaller process group than the global one.
As a result, only some of the MoE parameters are correctly linked in
`link_hp_params`
[link](https://github.com/microsoft/DeepSpeed/blob/b33873d234cf6679a3046be9a137682c3469d1fb/deepspeed/runtime/zero/stage_1_and_2.py#L568),
while the remaining parameters have `_hp_mapping` set to `None`.
Consequently, those parameters are missing from
`self._param_slice_mappings = self._create_param_mapping()`, which
directly causes errors when saving the optimizer state for MoE
parameters.
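The arithmetic behind the bug can be seen with a small standalone sketch (the sizes below are hypothetical, chosen only for illustration; the variable names mirror those in the code above):

```python
# Hypothetical sizes for illustration only.
numel = 1024               # elements in the flattened MoE (expert) param group
global_dp_world_size = 8   # dist.get_world_size(group=self.dp_process_group)
ep_size = 2                # expert parallelism degree
# Size of the expert group's real_dp_process_group[i]:
expert_dp_world_size = global_dp_world_size // ep_size

# Buggy: partition size computed with the global world size.
buggy_partition_size = numel // global_dp_world_size
# Only this many elements get linked across the expert data-parallel group:
covered = buggy_partition_size * expert_dp_world_size  # half the tensor is left unmapped

# Fixed: use the world size of the group's own process group,
# so the partitions tile the entire flattened tensor.
fixed_partition_size = numel // expert_dp_world_size
assert fixed_partition_size * expert_dp_world_size == numel
```

With the buggy computation, the four expert-parallel ranks jointly cover only 512 of the 1024 elements, leaving the rest without an `_hp_mapping`; with the fix, the partitions cover the whole flat tensor.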
To fix this bug, we need to use the correct `dp_world_size` for each
parameter group:
```python
def _link_all_hp_params(self):
    if self.cpu_offload:
        self._get_offload_gradient_dict()

    for i, _ in enumerate(self.optimizer.param_groups):
        # Link bit16 and fp32 params in partition
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size(group=self.real_dp_process_group[i])  # <--
        flat_hp_partition = self.single_partition_of_fp32_groups[i]
        link_hp_params(lp_param_list=self.bit16_groups[i],
                       flat_hp_partition=flat_hp_partition,
                       gradient_dict=self.averaged_gradients,
                       offload_gradient_dict=self.offload_gradient_dict,
                       use_offload=self.cpu_offload,
                       param_group_index=i,
                       partition_start=partition_id * partition_size,
                       partition_size=partition_size,
                       dp_group=self.real_dp_process_group[i])
```
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>