A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)
This is a faster and more memory-efficient implementation of
`zero_to_fp32`.
The previous version doubled the memory usage, which caused CPU OOM for
very large models (e.g. Llama 405B).
https://github.com/microsoft/DeepSpeed/blob/b647fb2470f8f6fefe5cab0ea84a2d89696eb898/deepspeed/utils/zero_to_fp32.py#L438-L441
## How does it work?
1. **Lazy loading**: Load the checkpoint with `mmap=True`, so the
weights are memory-mapped rather than all storages being loaded into
memory.
2. **Lazy merge**: `GatheredTensor` holds the mmapped weights and
tensor offsets. It is a memory-efficient pseudo tensor. Only when
`tensor.contiguous()` is called does it load the related weights into
memory and merge them into a single tensor.
3. **Release memory in time**: Save checkpoints shard by shard, and
release the memory once a shard is saved.
Throughout the process, only one shard of tensors is kept in memory.
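The three steps above can be sketched with the stdlib `mmap` module as a stand-in for `torch.load(..., mmap=True)`. The `LazyGatheredTensor` class, its `materialize` method, and the shard layout below are all illustrative, not the actual DeepSpeed implementation:

```python
import mmap
import os
import struct
import tempfile

class LazyGatheredTensor:
    """Memory-efficient pseudo tensor (illustrative sketch).

    Holds references to mmapped flat shards plus per-shard offsets;
    nothing is copied until materialize() is called -- the analogue
    of tensor.contiguous() in the PR.
    """

    def __init__(self, shards, offsets, lengths):
        self.shards = shards    # mmap objects, one per checkpoint shard
        self.offsets = offsets  # element offset of this param in each shard
        self.lengths = lengths  # elements contributed by each shard

    def materialize(self):
        # Only now is memory for the merged parameter allocated; each
        # slice is copied straight out of the memory-mapped shard files.
        values = []
        for mm, off, n in zip(self.shards, self.offsets, self.lengths):
            values.extend(struct.unpack_from(f"{n}d", mm, off * 8))
        return values

def demo():
    # Write two tiny "shards" of float64 values to disk.
    tmp = tempfile.mkdtemp()
    paths = [os.path.join(tmp, f"shard{i}.bin") for i in range(2)]
    for path, vals in zip(paths, ([1.0, 2.0, 9.0], [3.0, 4.0])):
        with open(path, "wb") as f:
            f.write(struct.pack(f"{len(vals)}d", *vals))
    # Lazy loading: mmap the shard files instead of reading them into RAM.
    files = [open(p, "rb") for p in paths]
    maps = [mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) for f in files]
    # Lazy merge: the first two elements of each shard form one parameter.
    t = LazyGatheredTensor(maps, offsets=[0, 0], lengths=[2, 2])
    merged = t.materialize()
    for mm, f in zip(maps, files):
        mm.close()
        f.close()
    return merged
```

Here `materialize()` plays the role of `tensor.contiguous()`: memory for the merged parameter exists only after that call, and can be released as soon as the shard containing it has been written out.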
## How much benefit in speed and memory?
Experiments were conducted on a Linux host with 1 TB of memory. Here is a
detailed comparison:
| | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|----------------------|------------|--------------|--------------------|
| llama3-8B(old->new) | 8 | 90 -> 41 | 0:02:17 -> 0:01:10 |
| llama2-13B(old->new) | 8 | 146 -> 54 | 0:02:30 -> 0:01:47 |
| llama2-70B(old->new) | 16 | 789 -> 159 | 0:20:47 -> 0:20:45 |
| qwen1.5-110B(old->new) | 32 | OOM -> 217 | ? -> 0:34:21 |
| llama3-405B(old->new) | 192 | OOM -> 262 | ? -> 2:09:59 |
You can reproduce the results with the following script:
```sh
# 1. install requirements
apt-get install time
# 2. prepare zero-3 checkpoints
# 3. convert zero to fp32 checkpoints
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```
- **memory**: Theoretically, this PR reduces the memory cost from `2M`
to `(1/n)M`, where `M` is the memory cost of the full weights and `n` is
the number of shards.
- **speed**: The speed gain mainly comes from avoiding extra tensor
copies. The benefit may be slight.
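A back-of-envelope check of the `2M -> (1/n)M` estimate. The parameter count is Llama3-405B's nominal size and the shard count is an assumption for illustration, not a measurement:

```python
# Rough arithmetic only; not measured peak-memory figures.
params = 405e9            # Llama3-405B parameter count
M = params * 4 / 1e9      # full fp32 weights in GB
old_peak = 2 * M          # previous implementation held ~2 full copies
n = 8                     # assumed number of output shards
new_peak = M / n          # only one shard's worth resident at a time

print(M, old_peak, new_peak)  # 1620.0 3240.0 202.5
```

This is consistent with the table above: the old `2M` cost (~3.2 TB) cannot fit on a 1 TB host, while one shard at a time does.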
## Impl history
- [v1](https://github.com/xu-song/DeepSpeed/commit/19712a1c75bfc1da4a7f3ecca6915a86af671568#diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447):
a hf_hub-compatible approach.
It was discarded due to the controversial implementation of
`data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simple
approach with `torch.empty`
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>