DeepSpeed
af033831 - Release overlap_comm & contiguous_gradients restrictions for ZeRO 1 (#4887)

The `overlap_comm` and `contiguous_gradients` options have been ignored in ZeRO stage 1 since https://github.com/microsoft/DeepSpeed/pull/1246. At that time, ZeRO 1 and 2 were implemented separately (see https://github.com/microsoft/DeepSpeed/tree/6ae756c03f12674f17aef90622e7664a8af9d2af/deepspeed/runtime/zero). ZeRO 1 did not register gradient hooks to overlap the backward pass with gradient all-reduce, so ignoring `overlap_comm` and `contiguous_gradients` was reasonable.

However, in the current implementation, ZeRO 1 and 2 share almost the same code path (`stage_1_and_2.py`), so features like `overlap_comm` and `contiguous_gradients` can also be enabled for ZeRO 1 (please correct me if I made a mistake).

With this PR, turning on `overlap_comm` and `contiguous_gradients` for ZeRO 1 on the [SFT task](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning) produces exactly the same training curve as the latest master.

![image](https://github.com/microsoft/DeepSpeed/assets/39846316/bda3be7b-c236-4e08-b687-b3cd01f5cc73)

I also see a ~1.05x end-to-end speedup from overlapping backward and gradient all-reduce. The trace confirms that backward and all-reduce do overlap, and that the separate gradients are indeed copied into a flat buffer, so these options are also effective for ZeRO 1.

![image](https://github.com/microsoft/DeepSpeed/assets/39846316/5f876296-e1b4-404b-8b33-03cee8e5e6b2)
![image](https://github.com/microsoft/DeepSpeed/assets/39846316/9654f6be-5c7a-401a-b0bc-413ecd3f4e6b)

Related issue: https://github.com/microsoft/DeepSpeed/issues/2295

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
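For reference, a minimal sketch of a DeepSpeed config that exercises this change by enabling both options under ZeRO stage 1. The `model` object and batch size are placeholders, not part of the PR:

```python
import deepspeed

# Sketch only: `model` stands in for any torch.nn.Module.
# With this PR, `overlap_comm` and `contiguous_gradients` are
# honored at stage 1 rather than silently ignored.
ds_config = {
    "train_batch_size": 16,  # placeholder value
    "zero_optimization": {
        "stage": 1,                    # ZeRO stage 1 (optimizer state partitioning)
        "overlap_comm": True,          # overlap backward with gradient all-reduce
        "contiguous_gradients": True,  # copy gradients into a flat buffer
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```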
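To illustrate the mechanism the PR relies on, here is a minimal PyTorch sketch (not DeepSpeed's actual hook code in `stage_1_and_2.py`) of overlapping backward with gradient all-reduce via per-parameter gradient hooks; `register_overlap_hooks` is a hypothetical helper:

```python
import torch
import torch.distributed as dist

# Sketch of the idea behind `overlap_comm`: a hook fires as soon as each
# parameter's gradient is produced during backward and launches an
# asynchronous all-reduce, so communication runs concurrently with the
# remainder of the backward pass.
def register_overlap_hooks(model: torch.nn.Module):
    pending = []

    def make_hook():
        def hook(grad: torch.Tensor) -> torch.Tensor:
            # async_op=True returns immediately with a work handle;
            # the reduction proceeds while backward keeps computing.
            pending.append(dist.all_reduce(grad, async_op=True))
            return grad
        return hook

    for p in model.parameters():
        if p.requires_grad:
            p.register_hook(make_hook())
    return pending
```

After `loss.backward()`, the caller would wait on each handle in `pending` (and clear the list) before the optimizer step, so the step only sees fully reduced gradients.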