sequence parallel with communication overlap (#5691)
SP is a fantastic piece of work, it is very elegant and conciseļ¼ at the
current stage, a transformer layer's forward and backward passes involve
8 all-to-all operations, with 5 opportunities for overlapping
communication:
Forward pass: The QKV matrix operations can be pipelined alongside some
of the all-to-all communications.
Backward pass: DQ, DK, DV all-to-all communications can be pipelined
alongside matrix operations.
Backward pass: DO_w can be parallel with DO_input, involving matrix
operations and all-to-all communications. Similar overlap-comm
strategies are used in Megatron for TP/TP-sp parallelism.
I tested under conditions of 1N8C zero1, disabled activation
checkpointing, ds-sp=8, and gbs=16:
1B 64K
7B 16K
They showed over 10% improvement (where I found that for mega-ds, using
split QKV itself can also enhance performance due to reducing slice +
cat operations in fwd/bwd), despite some TFLOPs already performing at a
relatively good level.
co-work with https://github.com/microsoft/Megatron-DeepSpeed/pull/415
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>