Remove optimizer step on initialization (#5104)
All ZeRO 1/2/3 stages call the optimizer's `step()` on its
initialization. This increments a counter in the optimizer and produces
a different result in parameter update with the normal usage of PyTorch.
This PR eliminates `step()` in the initialization and lazily configures
some internal states (linking *hp_params*) after the first `step()`
call.
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>