Z1/2 init: flatten params on device (#7828)
This PR addresses #7677 by flattening parameter tensors on the
accelerator instead of the CPU during ZeRO stage 1 and 2
initialization. This alleviates CPU contention, with the caveat
that the optimization is only applied when there is enough VRAM to
allocate a full copy of the parameter buffers.
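A minimal sketch of that gating logic, assuming a simple headroom check (the function name, the `headroom` factor, and the byte counts below are illustrative, not the actual DeepSpeed implementation):

```python
def pick_flatten_device(param_bytes: int, free_accelerator_bytes: int,
                        headroom: float = 2.0) -> str:
    """Flatten on the accelerator only when it can hold a full extra
    copy of the parameter buffers (flat buffer + originals);
    otherwise fall back to the old CPU path."""
    if free_accelerator_bytes >= headroom * param_bytes:
        return "accelerator"
    return "cpu"

# Example: 10 GiB of params with 64 GiB free picks the device path;
# with only 12 GiB free it falls back to the CPU.
print(pick_flatten_device(10 << 30, 64 << 30))  # accelerator
print(pick_flatten_device(10 << 30, 12 << 30))  # cpu
```

In a real PyTorch implementation the free-byte figure would come from something like `torch.cuda.mem_get_info()`.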
On 8 x H100s and an Intel Xeon Platinum 8480+, profiling the
initialization of DeepSpeed on 32 layers of `Qwen3-30B` with ZeRO
stage 2 gives the following:
Old = ~382s
New = ~130s
-------------------------
If necessary, this optimization can be extended to allow a tiered
system that trades VRAM for performance, which might look like the
following:
```
if enough VRAM for 2x model_size:
    naive flatten
else if enough VRAM for model_size / N:
    distributed flatten across N devices
else:
    flatten on CPU
```
The distributed flatten would involve each device flattening a portion
of the parameters and then performing an all-gather to assemble the
full flattened buffer. See #7677 for the original discussion.
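The distributed variant described above can be sketched with plain Python lists standing in for tensors (the partitioning scheme and the final concatenation step are illustrative; a real version would use `torch.distributed.all_gather` over device shards):

```python
from itertools import chain


def distributed_flatten(params: list[list[float]], n_devices: int) -> list[float]:
    """Sketch of the tiered scheme: each 'rank' flattens only its
    contiguous slice of the parameter list, then an all-gather
    concatenates the per-rank shards into the full flat buffer."""
    per_rank = (len(params) + n_devices - 1) // n_devices  # ceil division
    shards = []
    for rank in range(n_devices):
        chunk = params[rank * per_rank:(rank + 1) * per_rank]
        shards.append(list(chain.from_iterable(chunk)))  # local flatten
    # All-gather: every rank ends up with the concatenation of all shards.
    return list(chain.from_iterable(shards))


# Three parameter "tensors" flattened across two ranks.
print(distributed_flatten([[1, 2], [3], [4, 5]], n_devices=2))
```

Each rank only needs VRAM for its own shard during the flatten, which is where the `model_size / N` tier in the pseudocode comes from.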
---------
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: leejianwoo-collab <leejianwoo@gmail.com>
Signed-off-by: vensen <vensenmu@gmail.com>
Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Co-authored-by: nathon <leejianwoo@gmail.com>
Co-authored-by: Vensen <vensenmu@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: jp <jsb10121249@gmail.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>