Enable hpz based on secondary tensor presence (#4906)
Previously we used a series of forward/backward flags to control whether hpz
should be enabled on a given allgather call. This PR simplifies this by
enabling hpz only when its secondary tensor exists (and invalidating the
secondary tensor whenever the master weights change). This should:
1. Prevent potential out-of-sync issues compared with our current approach
of overwriting the secondary tensor.
2. Improve throughput, because hpz will now be enabled in many more
scenarios, including i) activation checkpointing, ii) gradient
accumulation, iii) `torch.no_grad` context, iv) `model.eval()` mode,
v) LoRA frozen weights, and vi) gradient overflow.
This fixes https://github.com/microsoft/DeepSpeed/issues/4851.
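A minimal sketch of the idea, assuming a toy parameter class (the names `HpzParam`, `invalidate_secondary`, and the placeholder all-gather helpers are illustrative, not DeepSpeed's actual API): the all-gather takes the hpz path only when a secondary tensor is present, and a weight update simply drops the secondary tensor.

```python
# Illustrative sketch only -- not DeepSpeed's real implementation.
import torch


class HpzParam:
    """Toy stand-in for a ZeRO-3 parameter shard with an optional hpz copy."""

    def __init__(self, full_numel: int):
        self.ds_tensor = torch.randn(full_numel)  # primary (global) shard
        self.ds_secondary_tensor = None           # hpz (intra-node) shard

    def allgather(self) -> torch.Tensor:
        # hpz is used only when the secondary tensor exists; otherwise fall
        # back to the regular global all-gather. No forward/backward flags.
        if self.ds_secondary_tensor is not None:
            return self._allgather_from_secondary()
        full = self._allgather_from_primary()
        # Cache the gathered weights as the secondary tensor so subsequent
        # all-gathers (activation checkpointing, gradient accumulation,
        # eval, ...) can take the cheaper hpz path.
        self.ds_secondary_tensor = full.clone()
        return full

    def invalidate_secondary(self) -> None:
        # Called whenever the master weights change (optimizer step), so a
        # stale hpz copy can never be gathered again.
        self.ds_secondary_tensor = None

    def _allgather_from_primary(self) -> torch.Tensor:
        return self.ds_tensor  # placeholder for the global all-gather

    def _allgather_from_secondary(self) -> torch.Tensor:
        return self.ds_secondary_tensor  # placeholder for the intra-node all-gather


# Usage: repeated all-gathers reuse the secondary tensor until a weight update.
p = HpzParam(full_numel=8)
w1 = p.allgather()        # global path; secondary tensor gets populated
w2 = p.allgather()        # hpz path
p.invalidate_secondary()  # optimizer step changed the master weights
w3 = p.allgather()        # global path again
```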
Convergence test:
- llama-2-7b with random weights, using
https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b.sh:
> zero-3 Baseline: Evaluating perplexity, Epoch 4/4: ppl: 5.151907920837402, loss: 1.6393671035766602
> hpz with this PR: ppl: 5.081737518310547, loss: 1.6256532669067383
- llama-2-7b with pretrained weights and LoRA, using
https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b_lora.sh:
> zero-3 Baseline: Evaluating perplexity, Epoch 4/4: ppl: 1.8326854705810547, loss: 0.6057823896408081
> hpz with this PR: ppl: 1.8326854705810547, loss: 0.6057823896408081
Performance test on 32 V100 GPUs, using the same script
(https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/training_scripts/llama2/run_llama2_7b.sh):
- gradient accumulation steps = 8
> master branch with hpz: SamplesPerSec=17.567813158654847
> this patch with hpz: SamplesPerSec=24.121657876029225
- LoRA
> master branch with hpz: SamplesPerSec=33.88883430864484
> this patch with hpz: SamplesPerSec=43.39463460004735
---------
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>