Add custom reshaping for universal checkpoint (#5390)
This PR adds more flexibility to define weight tensor reshaping for
universal checkpointing.
Currently, universal checkpointing assumes a few partitioning patterns for
tensor parallelism, such as column- or row-wise partitioning of a 2-dim
tensor. However, these are not flexible enough to describe the partitioning
used in more complex cases. Here are some examples:
1) MoE: The user may define the weight tensor for MoE's FFN as
[n_experts * hidden_out, hidden_in]. For TP, we need to *view* this
tensor as a 3-dim tensor and partition it along the `hidden_out` dimension.
2) GQA: The weights for QKV are often represented as one tensor, and Q, K,
and V may have different sizes. The tensor shape will be
[q_size + k_size + v_size, hidden]. We partition this along the first
dimension, but separately for each of Q, K, and V. In this case, we first
need to partition Q, K, and V individually and then concatenate the pieces
to get a shard for TP (see the sketch after this list).
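As a minimal sketch (plain PyTorch with hypothetical sizes, not DeepSpeed code), this is what the two partitioning patterns above amount to when done by hand:
```python
import torch

tp_degree, tp_rank = 4, 1
num_experts, hidden_out, hidden_in = 8, 1024, 512
q_size, k_size, v_size, hidden_size = 1024, 256, 256, 1024

# 1) MoE FFN weight stored as [n_experts * hidden_out, hidden_in]:
#    view it as 3-dim and shard along the hidden_out dimension (dim=1).
moe_fc1 = torch.randn(num_experts * hidden_out, hidden_in)
moe_view = moe_fc1.view(num_experts, hidden_out, hidden_in)
moe_shard = moe_view.chunk(tp_degree, dim=1)[tp_rank]  # [n_experts, hidden_out/tp, hidden_in]

# 2) Fused QKV weight stored as [q_size + k_size + v_size, hidden]:
#    split into Q, K, V, shard each along dim 0, then concatenate the pieces.
qkv = torch.randn(q_size + k_size + v_size, hidden_size)
q, k, v = torch.split(qkv, [q_size, k_size, v_size], dim=0)
qkv_shard = torch.cat([t.chunk(tp_degree, dim=0)[tp_rank] for t in (q, k, v)], dim=0)
```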
We propose a new pattern, `PARAMETER_WITH_SUB_PARAMS`, to support this.
Here is how it can be used to cover the above use cases: you define the
view of the weight tensor and specify the partitioning dimension based on
that view.
```python
from dataclasses import asdict
from deepspeed.checkpoint import PARAMETER_WITH_SUB_PARAMS, SubparamShape

info[PARAMETER_WITH_SUB_PARAMS] = [
    # MoE FFN weight: view as [num_experts, hidden_out, hidden_in], shard along hidden_out
    asdict(SubparamShape(patterns=[layers_prefix + r"\d+moe.fc1.weight"],
                         shape=(num_experts, hidden_out, hidden_in),
                         partition_dim=1)),
    # Fused QKV weight: sub-params of sizes (q_size, k_size, v_size), each sharded along dim 0
    asdict(SubparamShape(patterns=[layers_prefix + r"\d+.qkv.weight"],
                         shape=((q_size, k_size, v_size), hidden_size),
                         partition_dim=0)),
    ...
]
```
The conversion script (`ds_to_universal.py`) merges TP-sharded weight
tensors according to this information, and the universal checkpoint loader
partitions them in the same way when the checkpoint is loaded.
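For intuition, here is a minimal sketch of the merge direction for the fused QKV case: `merge_qkv_shards` is a hypothetical helper, not the actual `ds_to_universal.py` code, and only illustrates how per-rank shards could be recombined into the full tensor.
```python
import torch

def merge_qkv_shards(shards, q_size, k_size, v_size):
    """Each shard is [(q_size + k_size + v_size) / tp, hidden]: split every shard
    into its Q/K/V pieces, then concatenate all Q pieces, all K pieces, all V pieces."""
    tp = len(shards)
    splits = [q_size // tp, k_size // tp, v_size // tp]
    per_rank = [torch.split(s, splits, dim=0) for s in shards]
    return torch.cat([per_rank[r][i] for i in range(3) for r in range(tp)], dim=0)

# e.g. with q_size=1024, k_size=v_size=256, hidden=1024 and tp=4:
shards = [torch.randn((1024 + 256 + 256) // 4, 1024) for _ in range(4)]
full = merge_qkv_shards(shards, 1024, 256, 256)  # -> [1536, 1024]
```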
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>