Deduplicate fp32 weights under torch autocast and ZeRO3 (#7651)
When torch autocast is enabled, model weights are already in fp32 and
can be updated directly by the optimizer with fp32 gradients. Keeping
another fp32 copy as the master weights is therefore a waste of
accelerator memory.
This PR uses aliases to the so-called "fp16" params as the master weights to
save that memory. The optimization applies only when no optimizer offloading
(to CPU or NVMe) or swapping mechanism is enabled.
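
A minimal sketch of the idea follows; the helper and its arguments are
illustrative only and do not mirror DeepSpeed's actual internals:

```python
import torch


def master_weights_for_partition(flat_partition: torch.Tensor,
                                 torch_autocast_enabled: bool,
                                 offload_or_swap_enabled: bool) -> torch.Tensor:
    """Pick the fp32 master weights for one flat ZeRO-3 partition (illustrative)."""
    if (torch_autocast_enabled
            and not offload_or_swap_enabled
            and flat_partition.dtype == torch.float32):
        # Under torch autocast the "fp16" partition already holds fp32 weights,
        # so the optimizer can update it in place: alias it instead of cloning.
        return flat_partition
    # Otherwise keep the usual behavior: a detached fp32 copy as master weights.
    return flat_partition.detach().clone().float()
```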
Using
https://gist.github.com/eternalNight/3c2cf8c703f1e9e7742d3b7f9e1edae3
(which enables torch autocast) as an example, the memory profile of the
training startup phase is as follows:
<img width="3172" height="1915" alt="Picture1"
src="https://github.com/user-attachments/assets/ffd40042-3582-4c82-9072-e1fdf8d49a63"
/>
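
For context, the gist relies on DeepSpeed's native torch autocast integration
rather than wrapping the forward pass manually. A hedged sketch of that kind of
setup is below; the concrete values (autocast dtype, optimizer settings, model)
are assumptions for illustration, not copied from the gist:

```python
import deepspeed
import torch

# ZeRO-3 plus DeepSpeed's "torch_autocast" config section; no fp16/bf16 engine
# section is needed because the weights stay in fp32 under torch autocast.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},
    "torch_autocast": {"enabled": True, "dtype": "bfloat16"},
}

model = torch.nn.Linear(1024, 1024)  # stand-in for the real model in the gist
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
```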
With this PR, the master weights are no longer instantiated:
<img width="2990" height="1753" alt="Picture2"
src="https://github.com/user-attachments/assets/1d1d3411-0735-4bd1-8061-3e015040ce74"
/>
This is also true when DeepCompile is enabled:
<img width="3094" height="2083" alt="Picture3"
src="https://github.com/user-attachments/assets/c867d766-769a-4775-ac2a-3f1a1a723c32"
/>
When torch autocast is disabled, the separate fp32 master weights are still created as before:
<img width="2922" height="1471" alt="Picture4"
src="https://github.com/user-attachments/assets/5097ef57-2c7a-4fd0-b0c3-717c098ec52c"
/>
Signed-off-by: Junjie Mao <junjie.mao@linux.alibaba.com>