Delay reduce-scatter for ZeRO3 leaf modules (#5008)
ZeRO3 sets hooks on parameters to run reduce-scatter. This is often
problematic for MoE models. Our data parallel processes may activate
different sets of experts, but the hook is not fired unless the expert
is activated at a forward pass. The reduce-scatter is called only on
some processes in this case.
This PR delays reduce-scatter for ZeRO3 leaf modules (Refer to #4966) to
address the issue.
We no longer set reduce-scatter hooks on parameters of the leaf modules.
Instead, we launch reduce-scatter on all parameters belonging to the
leaf module when exiting the module during the backward pass.
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>