enable yuan autotp & add conv tp (#5428)
This PR enables AutoTP for the Yuan model and adds conv TP.
The Yuan model uses shared QK.
For example:
q_linear_out = [q1, q2, q3, q4, q5, ... , q16]
k_linear_out = [k1, k2, k3, k4, k5, ... , k16]
after share qk:
TP=1:
q' = [q1,q2,q3,q4, q9,q10,q11,q12, k1,k2,k3,k4, k9,k10,k11,k12]
k' = [q5,q6,q7,q8, q13,q14,q15,q16, k5,k6,k7,k8, k13,k14,k15,k16]
v' = [v1,v2,v3,v4, v5,v6,v7,v8, v9,v10,v11,v12, v13,v14,v15,v16]
TP=2:
rank0:
q'_0 = [q1,q2,q3,q4, k1,k2,k3,k4]
k'_0 = [q5,q6,q7,q8, k5,k6,k7,k8]
v'_0 = [v1,v2,v3,v4, v5,v6,v7,v8] -> v'_0 is wrong! The expected value is:
[v1,v2,v3,v4, v9,v10,v11,v12]
rank1:
q'_1 = [q9,q10,q11,q12, k9,k10,k11,k12]
k'_1 = [q13,q14,q15,q16, k13,k14,k15,k16]
v'_1 = [v9,v10,v11,v12, v13,v14,v15,v16] -> v'_1 is wrong! The expected value
is: [v5,v6,v7,v8, v13,v14,v15,v16]
To avoid modifying the modeling code, we adjust the value and oproj
weights to fit this QK layout.
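The value-weight adjustment can be sketched as follows. This is a minimal numpy sketch for the 16-unit, TP=2 example above, not the actual DeepSpeed code; `reorder_value_rows` is a hypothetical helper. The idea is to permute the row blocks of the value weight before sharding, so that a plain contiguous split across ranks hands each rank exactly the value units that match its shared-QK heads:

```python
import numpy as np

def reorder_value_rows(v_weight, n_blocks=4):
    """Permute row blocks of the value projection weight so a plain
    contiguous TP=2 split yields the shards the shared-QK layout expects.
    Blocks [b0, b1, b2, b3] become [b0, b2, b1, b3]: rank0 then owns
    [b0, b2] and rank1 owns [b1, b3]."""
    blocks = np.split(v_weight, n_blocks, axis=0)
    order = [0, 2, 1, 3]  # interleave halves so each rank gets its heads
    return np.concatenate([blocks[i] for i in order], axis=0)

# 16 "value units" v1..v16 as rows of a toy weight matrix
v = np.arange(1, 17).reshape(16, 1)
v_fixed = reorder_value_rows(v)
rank0, rank1 = np.split(v_fixed, 2, axis=0)
print(rank0.ravel().tolist())  # [1, 2, 3, 4, 9, 10, 11, 12]
print(rank1.ravel().tolist())  # [5, 6, 7, 8, 13, 14, 15, 16]
```

The oproj weight needs the matching permutation along its input dimension, since it consumes the attention output in the same head order.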
We also add conv TP to support models that include heavy conv
computation. It is similar to the linear TP policy:
if not last_conv_layer:
- 1. Divide the conv weight across ranks along the output channel
dimension.
- 2. Apply conv2d.
else:
- 1. Divide the conv weight across ranks along the input channel
dimension.
- 2. Apply conv2d.
- 3. Use allreduce to sum the partial outputs.
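The conv TP policy above can be sketched with 1x1 convs, which reduce to a matmul over channels. This is an illustrative numpy sketch, not the DeepSpeed implementation; `split_conv_weight` and `conv1x1` are hypothetical helpers, and the allreduce is simulated by summing the per-rank partial outputs:

```python
import numpy as np

def split_conv_weight(weight, tp_size, rank, last_conv_layer):
    """Shard a conv weight [out_ch, in_ch, kH, kW] for tensor parallelism.
    Non-final conv layers split along the output-channel dim (each rank
    computes a slice of the output channels); the final conv layer splits
    along the input-channel dim, and partial outputs are allreduce-summed."""
    axis = 1 if last_conv_layer else 0
    return np.split(weight, tp_size, axis=axis)[rank]

def conv1x1(x, w):
    # a 1x1 conv [out, in, 1, 1] on x [in, H, W] is a matmul over channels
    return np.einsum('oi,ihw->ohw', w[:, :, 0, 0], x)

# two stacked 1x1 convs, run with TP=2 and compared to the serial result
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 3))        # input [in_ch=4, H=3, W=3]
w1 = rng.normal(size=(8, 4, 1, 1))    # first conv: split output channels
w2 = rng.normal(size=(4, 8, 1, 1))    # last conv: split input channels

serial = conv1x1(conv1x1(x, w1), w2)

partials = []
for rank in range(2):
    w1_s = split_conv_weight(w1, 2, rank, last_conv_layer=False)
    w2_s = split_conv_weight(w2, 2, rank, last_conv_layer=True)
    partials.append(conv1x1(conv1x1(x, w1_s), w2_s))
parallel = partials[0] + partials[1]  # the allreduce (sum) step

assert np.allclose(serial, parallel)
```

Splitting the last conv along the input channels is what makes the allreduce a simple sum: each rank computes a partial contraction over its slice of input channels, and the sum of partials equals the full contraction.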
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>