# dd402694 - A faster and more memory-efficient implementation of `zero_to_fp32` (#6658)

This is a faster and more memory-efficient implementation of `zero_to_fp32`. The previous version doubles the memory usage, which causes CPU OOM for very large models (e.g. llama 405B).

https://github.com/microsoft/DeepSpeed/blob/b647fb2470f8f6fefe5cab0ea84a2d89696eb898/deepspeed/utils/zero_to_fp32.py#L438-L441

## How does it work?

1. **Lazy loading**: Load the checkpoint with `mmap=True`, so the weights are mmapped rather than all storages being loaded into memory.
2. **Lazy merge**: `GatheredTensor` holds the mmapped weights and the tensor offset. It is a memory-efficient pseudo tensor. Only when `tensor.contiguous()` is called does it load the related weights into memory and merge them into a single tensor.
3. **Release memory in time**: Save the checkpoint shard by shard, and release the memory once a shard is saved. Throughout the process, only one shard of tensors is kept in memory.

Illustrative sketches of these three steps appear at the end of this description.

## How much benefit in speed and memory?

Experiments were conducted on a Linux host with 1 TB of memory. Here is a detailed comparison:

|                          | world size | peak memory (GB) | elapsed time (h:mm:ss) |
|--------------------------|------------|------------------|------------------------|
| llama3-8B (old->new)     | 8          | 90 -> 41         | 0:02:17 -> 0:01:10     |
| llama2-13B (old->new)    | 8          | 146 -> 54        | 0:02:30 -> 0:01:47     |
| llama2-70B (old->new)    | 16         | 789 -> 159       | 0:20:47 -> 0:20:45     |
| qwen1.5-110B (old->new)  | 32         | OOM -> 217       | ? -> 0:34:21           |
| llama3-405B (old->new)   | 192        | OOM -> 262       | ? -> 2:09:59           |

You can reproduce the results with the following script:

```sh
# 1. install requirements (GNU time, for peak-memory measurement)
apt-get install time

# 2. prepare zero-3 checkpoints

# 3. convert zero to fp32 checkpoints; `-v` reports peak resident memory
/usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization
```

- **memory**: Theoretically, this PR reduces the memory cost from `2M` to `(1/n)M`, where `M` is the memory cost of the full weights and `n` is the number of shards.
- **speed**: The speed gain mainly comes from avoiding extra tensor copying. The benefit may be slight.

## Impl history

- [v1](https://github.com/xu-song/DeepSpeed/commit/19712a1c75bfc1da4a7f3ecca6915a86af671568#diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447): a hf_hub-compatible approach. It was discarded due to the controversial implementation of `data_ptr()`.
- [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simple approach with `torch.empty`

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
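As a sketch of step 1, a minimal lazy-loading example, assuming PyTorch >= 2.1 (where `torch.load` accepts `mmap=True`). The file name is a placeholder, not the actual ZeRO-3 checkpoint layout:

```python
import torch

# With mmap=True the tensor storages are memory-mapped: pages are faulted
# in from disk on first access instead of being read eagerly into RAM.
# The file name below is illustrative only.
state = torch.load(
    "zero_checkpoint_shard.pt",
    map_location="cpu",
    mmap=True,
    weights_only=False,
)
```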
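For step 2, a simplified stand-in for the `GatheredTensor` idea (not the exact class from this PR). It assumes each rank's fp32 partition is an equally sized, mmapped 1-D tensor, and it allocates real memory only when `contiguous()` is called:

```python
import math
import torch

class GatheredTensor:
    """Memory-efficient pseudo tensor: records where a parameter lives
    inside the mmapped flat partitions (one per rank) and materializes
    it only on .contiguous()."""

    def __init__(self, flat_partitions, offset, shape):
        self.flat_partitions = flat_partitions  # mmapped 1-D fp32 tensors
        self.offset = offset                    # element offset in the global flat view
        self.shape = torch.Size(shape)
        self.dtype = torch.float32

    def contiguous(self):
        numel = math.prod(self.shape)
        out = torch.empty(numel, dtype=self.dtype)   # allocate just this parameter
        part_size = self.flat_partitions[0].numel()  # equal-sized partitions assumed
        written = 0
        while written < numel:
            pos = self.offset + written                  # position in the flat view
            rank, local = divmod(pos, part_size)         # owning rank and local offset
            n = min(numel - written, part_size - local)  # contiguous run to copy
            out[written:written + n] = self.flat_partitions[rank][local:local + n]
            written += n
        return out.view(self.shape)
```

The point of this shape is that `torch.empty` allocates only the one parameter being materialized, and slicing the mmapped partitions pages in only the bytes actually copied.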
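And for step 3, a hedged sketch of shard-by-shard saving. Here `state_dict` (name -> `GatheredTensor`) and `shards` (a grouping of parameter names, e.g. by size) are assumed inputs, and `save_file` is `safetensors.torch.save_file`:

```python
import gc
from safetensors.torch import save_file

def save_sharded(state_dict, shards, output_dir):
    """Only one shard's worth of real tensors is alive at any time."""
    for i, names in enumerate(shards):
        # Materialize just this shard's parameters.
        shard = {name: state_dict[name].contiguous() for name in names}
        save_file(shard, f"{output_dir}/model-{i + 1:05d}-of-{len(shards):05d}.safetensors")
        # Drop the materialized tensors before moving on, so peak memory
        # stays at roughly one shard -- the (1/n)M bound described above.
        del shard
        gc.collect()
```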
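As a concrete (hypothetical) illustration of the `2M` -> `(1/n)M` claim: for a model whose fp32 weights occupy `M = 800` GB, saved in `n = 16` shards, the old path would need roughly 1.6 TB resident, while the new path keeps only about 50 GB (one shard) materialized at a time, plus mmapped pages the OS is free to evict.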