DeepSpeed
3bdd187e - Fixing the reshape bug in sequence parallel alltoall, which corrupted all QKV data (#5664)

Currently, in the DeepSpeed sequence parallel implementation, two `all_to_all` operations are used in the distributed attention to scatter and gather the sequence. However, the `reshape` operation in the second [`all_to_all`](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L36) is wrong. **The model will never converge, as the data is corrupted by it.**

To easily check the problem with the current implementation, we can modify [this line](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L104) as follows:

```
def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
    """ forward

    Arguments:
        query (Tensor): query input to the layer
        key (Tensor): key input to the layer
        value (Tensor): value input to the layer
        args: other args

    Returns:
        * output (Tensor): context output
    """
    # TODO Merge three alltoall calls into one
    # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k, and v) together!
    # in shape : e.g., [s/p:h:]
    query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
    key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
    value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

    # out shape : e.g., [s:h/p:]
    context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

    output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

    # out e.g., [s/p::h]
    return output
```

Remove the attention computation, leaving only the `all_to_all` calls, and simply check `query` before and after the `all_to_all` round trip; the two should be identical:

```
query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

# out shape : e.g., [s:h/p:]
# context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)  # do not perform attn,
context_layer = query_layer  # just use the input query

output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

if torch.distributed.get_rank() == 3:
    print(query[0][15730][5])
    print(output[0][15730][5])
```

**_In the current implementation, `all_to_all` completely scrambles the data: the printed values of `query` are misaligned with those of `output`._**

The problem is caused by this incorrect [reshape](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L36C18-L36C24):

```
def single_all_to_all(input, scatter_idx, gather_idx, group):
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[: gather_idx] + \
        [inp_shape[gather_idx] * seq_world_size,] + \
        inp_shape[gather_idx + 1:]).contiguous()
```
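As background for the shape analysis below: when called without explicit split sizes, `dist.all_to_all_single` splits each rank's input into `seq_world_size` equal chunks along dim 0, sends chunk `r` to rank `r`, and concatenates the chunks received from all peers along dim 0, ordered by source rank. A minimal single-process emulation of that semantics (my sketch, not part of the patch):

```
import torch

def emulate_all_to_all_single(inputs):
    # inputs[r] is rank r's input tensor; all tensors share one shape,
    # and dim 0 must be divisible by the world size
    world = len(inputs)
    # chunks[src][dst] is the piece that rank `src` sends to rank `dst`
    chunks = [t.chunk(world, dim=0) for t in inputs]
    # each destination concatenates its received pieces, ordered by source rank
    return [torch.cat([chunks[src][dst] for src in range(world)], dim=0)
            for dst in range(world)]
```

This is why dim 0 of the buffer received in the second `all_to_all` enumerates source ranks: block `r` holds the chunk of heads that rank `r` owned.
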
When performing the second `all_to_all`, the [output](https://github.com/microsoft/DeepSpeed/blob/eda5075b88c448d13009301dc73653a224bb24b8/deepspeed/sequence/layer.py#L32) gathered from the other ranks has the following shape:

```
dist.all_to_all_single(output, input_t, group=group)
# output: [seq_world_size, batch, local_seq_length, num_local_heads, head_dim]
if scatter_idx < 2:
    output = output.transpose(0, 1).contiguous()
# output: [batch, seq_world_size, local_seq_length, num_local_heads, head_dim]
```

At this step we actually want to gather all the heads for the local sequence; therefore, the line above needs to be:

```
if scatter_idx < 2:
    output = output.transpose(0, 2)
# output: [batch, local_seq_length, seq_world_size, num_local_heads, head_dim]
```

Only by doing this can we then apply:

```
return output.reshape(
    inp_shape[: gather_idx] + \
    [inp_shape[gather_idx] * seq_world_size,] + \
    inp_shape[gather_idx + 1:]).contiguous()
```

which arranges the data correctly. A more straightforward example:

```
# second all_to_all
# batch: 1
# sequence parallel size: 4
# local sequence length: 8192
# total number of heads: 16
# head dim: 128
dist.all_to_all_single(output, input_t, group=group)
# output: [4, 1, 8192, 4, 128]
if scatter_idx < 2:
    output = output.transpose(0, 1).contiguous()
# output: [1, 4, 8192, 4, 128]
# At this step, you cannot directly reshape it into [1, 8192, 16, 128], as that corrupts the data.
# You need to permute output into [1, 8192, 4, 4, 128] first, then reshape it into [1, 8192, 16, 128].
```

For the first `all_to_all`, things work fine; this issue only exists in the second `all_to_all`.

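To see the corruption concretely without a multi-GPU run, here is a small single-process check (my addition, with toy sizes) that tags every element with its source rank, sequence position, and head index, then compares the buggy transpose-and-reshape against the permute-and-reshape described above:

```
import torch

# toy sizes: 2 ranks, batch 1, local seq length 3, 2 local heads, head dim 1
world, batch, seq, heads, dim = 2, 1, 3, 2, 1

# simulate the buffer received from the second all_to_all:
# recv[r, b, s, h, d] = 100*r + 10*s + h, so every element encodes
# (source rank, sequence position, head index)
recv = (100 * torch.arange(world).view(world, 1, 1, 1, 1) +
        10 * torch.arange(seq).view(1, 1, seq, 1, 1) +
        torch.arange(heads).view(1, 1, 1, heads, 1))
# recv: [world, batch, local_seq, local_heads, head_dim] = [2, 1, 3, 2, 1]

# buggy path: transpose(0, 1) leaves the rank axis next to the sequence
# axis, so the flat reshape interleaves sequence positions with head chunks
buggy = recv.transpose(0, 1).reshape(batch, seq, world * heads, dim)

# correct path: move the rank axis next to the heads axis, then flatten
good = recv.permute(1, 2, 0, 3, 4).reshape(batch, seq, world * heads, dim)

# at sequence position 1, the heads should be
# [rank0-head0, rank0-head1, rank1-head0, rank1-head1] = [10, 11, 110, 111]
print(buggy[0, 1].flatten())  # tensor([ 20,  21, 100, 101]) -- scrambled
print(good[0, 1].flatten())   # tensor([ 10,  11, 110, 111]) -- correct
```

The buggy layout mixes whole head chunks from different sequence positions and ranks, which is exactly the misalignment that the print check above exposes.
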
Co-authored-by: Jinghan Yao <jyao@athena.nic.uoregon.edu>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>