DeepSpeed
26ee3856 - [fix] fix test_zf.py hang bug (#8012)

Commit
20 days ago
[fix] fix test_zf.py hang bug (#8012) fix:When running following test command, it may hang. ```pytest -v runtime/zenflow/test_zf.py::TestZenFlowDistributed::test_zenflow_distributed[epoch-1-4-False-0-3]``` The reason is that when param.selected_indices got an empty result, its dtype would be torch.float32 instead of torch.int64. However, if the float32 empty tensor is used as an index just like grad_2d[param.selected_indices, :], it would cause a hang. So in order to solve this bug, I add a dtype cast to int64 when judge the param.selected_indices is empty, which means its original dtype is torch.float32. Signed-off-by: binchengxiong <binchengxiong@alibaba-inc.com> Co-authored-by: binchengxiong <binchengxiong@alibaba-inc.com> Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Author
Parents
Loading