Fix all-gather duplicate params and wrong dtype (#7462)
The following assertion error arises when torch autocast is enabled.
```
[rank3]:   File "/opt/deepspeed/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 337, in fetch_sub_module
[rank3]:     self.__inflight_param_registry.pop(param).wait(handle_dependency=not fast_fetch)
[rank3]:   File "/opt/deepspeed/deepspeed/runtime/zero/partition_parameters.py", line 787, in wait
[rank3]:     handle.wait(handle_dependency)
[rank3]:   File "/opt/deepspeed/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank3]:     ret_val = func(*args, **kwargs)
[rank3]:   File "/opt/deepspeed/deepspeed/runtime/zero/partition_parameters.py", line 750, in wait
[rank3]:     assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight"
[rank3]: AssertionError: expected param {'id': 685, 'status': 'AVAILABLE', 'numel': 131334144, 'ds_numel': 131334144, 'shape': (32064, 4096), 'ds_shape': (32064, 4096), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': set(), 'ds_tensor.shape': torch.Size([16416768])} to be inflight
```
This is caused by multiple all-gather ops in the same coalesced all-gather sharing the same list of params (of mixed dtypes), so a param can end up gathered more than once.

Make each all-gather exchange only params of a single dtype, and pass an all-gather dtype that matches those params.
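A minimal sketch of the idea, assuming the coalesced path receives a flat list of params; `partition_params_by_dtype` and the `all_gather_coalesced(group, dtype=...)` call shown in the comments are hypothetical names for illustration, not the actual DeepSpeed API:

```python
from collections import defaultdict

import torch

def partition_params_by_dtype(params):
    """Group params by dtype so each coalesced all-gather stays homogeneous.

    Hypothetical helper illustrating the fix: if the ops inside one
    coalesced all-gather share a mixed-dtype param list, a param can be
    gathered twice (tripping the INFLIGHT assertion above) or gathered
    with a dtype that does not match its storage.
    """
    buckets = defaultdict(list)
    for p in params:
        buckets[p.dtype].append(p)
    return buckets

if __name__ == "__main__":
    params = [torch.nn.Parameter(torch.zeros(4, dtype=d))
              for d in (torch.float16, torch.float32, torch.float16)]
    for dtype, group in partition_params_by_dtype(params).items():
        # One coalesced all-gather would be issued per dtype group here,
        # e.g. all_gather_coalesced(group, dtype=dtype); the call name and
        # signature are assumptions, not the actual DeepSpeed entry point.
        print(dtype, len(group))
```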
Signed-off-by: Junjie Mao <banxing.mjj@alibaba-inc.com>
Co-authored-by: Junjie Mao <banxing.mjj@alibaba-inc.com>
Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com>