DeepSpeed
706f6e82 - Deduplicate fp32 weights under torch autocast and ZeRO3 (#7651)

When torch autocast is enabled, the model weights are already in fp32 and can be updated directly by the optimizer with fp32 gradients. Keeping another fp32 copy as the master weights wastes accelerator memory. Instead, use aliases to the so-called "fp16" params as the master weights to save memory. This applies only when no optimizer offloading (CPU or NVMe) or swapping mechanism is enabled.

Using https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3 (which enables torch autocast) as an example, the memory profile of the training startup phase is as follows:

[Figure: accelerator memory profile of the training startup phase before this change]

With this PR, the master weights are no longer instantiated:

[Figure: memory profile with this PR; no separate fp32 master weights allocated]

This is also true when DeepCompile is enabled:

[Figure: memory profile with this PR and DeepCompile enabled]

When torch autocast is disabled, the master weights are preserved:

[Figure: memory profile with torch autocast disabled; fp32 master weights still allocated]

Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>
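For illustration, here is a minimal sketch of the aliasing idea described above. It is not DeepSpeed's actual implementation; the helper name and its arguments are hypothetical. The point is that when torch autocast keeps module params in fp32 and no offloading or swapping is in play, the optimizer's master weights can simply alias the existing params instead of holding a separate fp32 clone.

```python
import torch

def build_master_weights(params, autocast_enabled, offload_or_swap_enabled):
    # Hypothetical helper illustrating the dedup logic.
    # With torch autocast, module params are already fp32 and gradients are
    # produced in fp32, so the optimizer can update the params in place.
    if autocast_enabled and not offload_or_swap_enabled:
        # Alias the existing fp32 params; no extra fp32 copy is allocated.
        return list(params)
    # Otherwise keep the usual behavior: a detached fp32 master copy
    # of each (possibly low-precision) param.
    return [p.detach().clone().float() for p in params]
```

Because the aliased params are exactly the tensors the optimizer would otherwise duplicate, this saves one full fp32 copy of the model on the accelerator; the fallback path keeps the existing master-weight behavior whenever offloading or swapping needs its own fp32 buffers.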