Fixing the reshape bug in sequence parallel alltoall, which corrupted all QKV data (#5664)
In the current DeepSpeed sequence parallel implementation, two
`all_to_all` operations are used in the distributed attention to scatter and gather
the sequence across ranks. However, the `reshape` operation in the second
[`all_to_all`](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L36) is wrong.
**The model will never converge because this corrupts the data.**
To easily check the problem in the current implementation, start from the `forward`
method at [this
line](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L104):
```
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
    """ forward

    Arguments:
        query (Tensor): query input to the layer
        key (Tensor): key input to the layer
        value (Tensor): value input to the layer
        args: other args

    Returns:
        * output (Tensor): context output
    """
    # TODO Merge three alltoall calls into one
    # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
    #in shape : e.g., [s/p:h:]
    query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
    key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
    value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

    #out shape : e.g., [s:h/p:]
    context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

    output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

    #out e.g., [s/p::h]
    return output
```
Remove the attention computation, leaving only the `all_to_all` calls, and simply
check `query` before and after the `all_to_all` round trip; the two should be
identical.
```
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

#out shape : e.g., [s:h/p:]
#context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)  # do not perform attn
context_layer = query_layer  # just use the input query

output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

if torch.distributed.get_rank() == 3:
    print(query[0][15730][5])
    print(output[0][15730][5])
```
**_In the current implementation, the `all_to_all` round trip completely scrambles
the data: the printed values of `query` do not match those of `output`._**
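Rather than eyeballing printed values, the same sanity check can be automated. The following is a minimal sketch (not part of the patch) that could be dropped into the same `forward` method; it only uses objects that already exist there (`self.spg`, `self.scatter_idx`, `self.gather_idx`, `query`). With the current reshape this check should fail, and after the fix it should pass:
```
# Sketch only: assert that scatter followed by gather is the identity on `query`.
round_trip = _SeqAllToAll.apply(
    self.spg,
    _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx),
    self.gather_idx,
    self.scatter_idx,
)
if not torch.allclose(round_trip, query):
    raise RuntimeError(f"rank {torch.distributed.get_rank()}: the sequence-parallel "
                       f"all_to_all round trip does not reproduce the input query")
```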
The problem is caused by this incorrect
[reshape](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L36C18-L36C24):
```
def single_all_to_all(input, scatter_idx, gather_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[: gather_idx] + \
        [inp_shape[gather_idx] * seq_world_size,] + \
        inp_shape[gather_idx + 1:]).contiguous()
```
When performing the second `all_to_all`, the
[output](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L32)
gathered from the other ranks has the following shape:
```
dist.all_to_all_single(output, input_t, group=group)
# output: [seq_world_size, batch, local_seq_length, num_local_heads, head_dim]
if scatter_idx < 2:
    output = output.transpose(0, 1).contiguous()
    # output: [batch, seq_world_size, local_seq_length, num_local_heads, head_dim]
```
At this step, we actually want to gather all the heads for the local sequence
chunk; therefore, the line above needs to be:
```
if scatter_idx < 2:
    output = output.transpose(0, 2)
    # output: [batch, local_seq_length, seq_world_size, num_local_heads, head_dim]
```
Only after this can the final reshape:
```
return output.reshape(
    inp_shape[: gather_idx] + \
    [inp_shape[gather_idx] * seq_world_size,] + \
    inp_shape[gather_idx + 1:]).contiguous()
```
arrange the data correctly.
A more straightforward example:
```
# second all_to_all
# batch: 1
# sequence parallel size: 4
# local sequence length: 8192
# total number of heads: 16
# head dim: 128
dist.all_to_all_single(output, input_t, group=group)
# output: [4, 1, 8192, 4, 128]
if scatter_idx < 2:
    output = output.transpose(0, 1).contiguous()
    # output: [1, 4, 8192, 4, 128]
    # At this step, you cannot directly reshape it into [1, 8192, 16, 128] as it corrupts the data.
    # You need to permute output into [1, 8192, 4, 4, 128], then reshape it into [1, 8192, 16, 128].
```
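The corruption, and the fix, can also be reproduced on a single process with a toy tensor, without launching a distributed job. The following is a self-contained sketch (toy dimensions and variable names are invented for illustration, no communication involved): it builds the gathered tensor directly in the `[seq_world_size, batch, local_seq_length, num_local_heads, head_dim]` view and compares the current `transpose(0, 1)` + reshape against the permute-then-reshape described above.
```
import torch

# Toy dimensions (hypothetical, small enough to check by hand).
seq_world_size, batch, local_seq, local_heads, head_dim = 4, 1, 8, 4, 2
total_heads = seq_world_size * local_heads

# Reference tensor in the layout we ultimately want:
# [batch, local_seq, total_heads, head_dim], where the global head index is
# head_group * local_heads + local_head.
reference = torch.arange(batch * local_seq * total_heads * head_dim, dtype=torch.float32)
reference = reference.reshape(batch, local_seq, total_heads, head_dim)

# What the second all_to_all conceptually hands back: one block per source rank,
# i.e. [seq_world_size, batch, local_seq, local_heads, head_dim].
gathered = (reference.reshape(batch, local_seq, seq_world_size, local_heads, head_dim)
            .permute(2, 0, 1, 3, 4)
            .contiguous())

# Current (buggy) post-processing: transpose(0, 1), then merge the head dims.
buggy = (gathered.transpose(0, 1).contiguous()
         .reshape(batch, local_seq, total_heads, head_dim))

# Corrected post-processing: move seq_world_size next to local_heads before merging.
fixed = (gathered.permute(1, 2, 0, 3, 4).contiguous()
         .reshape(batch, local_seq, total_heads, head_dim))

print(torch.equal(buggy, reference))  # False: the head/sequence data is scrambled
print(torch.equal(fixed, reference))  # True: heads are gathered correctly
```
Note that for a batch dimension of size 1, the explicit `permute(1, 2, 0, 3, 4)` used here produces the same element order as the `transpose(0, 2)` in the proposed fix.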
For the first `all_to_all`, things work fine. This issue only exists in
the second `all_to_all`.
Co-authored-by: Jinghan Yao <jyao@athena.nic.uoregon.edu>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>