Add APIs to offload states of model, optimizer, and engine (#6011)
This PR adds the following APIs to offload model, optimizer, and engine
states.
```python
def offload_states(self,
                   include: Container[OffloadStateTypeEnum] = None,
                   device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
                   pin_memory: bool = True,
                   non_blocking: bool = False) -> None:
    """Move the ZeRO optimizer buffers to the specified device.

    Arguments:
        include: Optional. The set of states to offload. If not provided, all states are offloaded.
        device: Optional. The device to move the ZeRO optimizer buffers to.
        pin_memory: Optional. Whether to pin the memory of the offloaded states.
        non_blocking: Optional. Whether to offload the states asynchronously.
    """
    ...

def offload_states_back(self, non_blocking: bool = False) -> None:
```
Here is the typical usage.
```python
# Offload after forward, backward, and step
model.offload_states()
# Do something requiring a lot of device memory
...
# Load states back to device memory
model.offload_states_back()
```
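To make the intended call sequence concrete, here is a minimal sketch with a hypothetical stand-in engine (the `FakeEngine` class below is purely illustrative; the real `model` is a DeepSpeed engine whose methods move actual ZeRO buffers):

```python
# Hypothetical stub illustrating the offload/reload call sequence.
# A real DeepSpeed engine moves ZeRO buffers; this only tracks where they live.
class FakeEngine:
    def __init__(self):
        self.states_on_device = True

    def offload_states(self, include=None, device="cpu",
                       pin_memory=True, non_blocking=False):
        self.states_on_device = False   # buffers now live on `device`

    def offload_states_back(self, non_blocking=False):
        self.states_on_device = True    # buffers restored to the accelerator


model = FakeEngine()
model.offload_states()                  # free device memory
assert not model.states_on_device       # room here for other device work
model.offload_states_back()             # reload before the next step
assert model.states_on_device
```

Note that `offload_states_back` must be called before the next forward/backward/step, since the engine's buffers are unavailable on the device while offloaded.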
You can selectively offload states to balance the offloading overhead against the memory savings.
```python
model.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.opt_states]), device=OffloadDeviceEnum.cpu)
```
Performance (4.3B parameters / 4x A100)
- Environment (4x A100, [benchmark
script](https://gist.github.com/tohtana/05d5faba5068cf839abfc7b1e38b85e4))
- Average device-to-host transfer bandwidth: 2.45 GB/s per GPU (9.79 GB/s aggregated)
- Average host-to-device transfer bandwidth: 11.05 GB/s per GPU (44.19 GB/s aggregated)
- Memory (allocated by PyTorch)
  - Before offloading: 18.2 GB
  - After offloading: 17.7 MB
- Time ([benchmark
script](https://github.com/microsoft/DeepSpeedExamples/tree/tohtana/offload_states/training/offload_states));
each cell shows offloading time / loading time
|Trial|pin_memory=0 non_blocking=0|pin_memory=0 non_blocking=1|pin_memory=1 non_blocking=0|pin_memory=1 non_blocking=1|
|--:|---------------------------|---------------------------|---------------------------|---------------------------|
| 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 |
| 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 |
| 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 |
| 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 |
| 5|4.4 / 3.35 |6.04 / 2.41 |6.26 / 2.41 |6.32 / 2.47 |
| 6|4.4 / 3.57 |6.58 / 2.42 |6.88 / 2.4 |6.35 / 2.43 |
| 7|9.51 / 3.12 |6.9 / 2.39 |6.9 / 2.39 |6.46 / 2.4 |
| 8|4.77 / 3.64 |6.69 / 2.39 |7.39 / 2.42 |6.56 / 2.46 |
| 9|9.5 / 3.07 |7.18 / 2.42 |6.67 / 2.39 |7.38 / 2.46 |
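As a rough sanity check (assuming the table entries are in seconds and that roughly the full 18.2 GB of allocated state is transferred per GPU — neither is stated explicitly above), the measured bandwidths predict times in the same ballpark as the table:

```python
# Back-of-envelope estimate from the numbers reported above.
state_gb = 18.2    # device memory allocated before offloading
d2h_gbps = 2.45    # average device-to-host bandwidth per GPU
h2d_gbps = 11.05   # average host-to-device bandwidth per GPU

est_offload_s = state_gb / d2h_gbps   # ~7.4 s, vs. ~6-7 s measured
est_load_s = state_gb / h2d_gbps      # ~1.6 s, vs. ~2.4 s measured
print(f"offload ~{est_offload_s:.1f}s, load ~{est_load_s:.1f}s")
```

The agreement is approximate, which is expected: the table times also include allocation, pinning, and synchronization overheads beyond the raw copy.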
TODO:
- Enable offloading to NVMe storage. NVMe support is non-trivial; I
suggest adding it in a separate PR.
- [DONE] Discard the buffer (and recreate it) instead of offloading it. We
don't need to restore the contiguous buffer for reduce.
- [DONE] Check whether pin_memory improves performance.
---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>