Fix potential memory issues when using DeepSpeed ZeRO-3 (#6726)
I hit an OOM while doing DPO training with ZeRO-3. DPO needs to call
the module twice in one training step, and the second call is under no_grad().
The OOM is caused by two bugs:
1. `__n_available_params`, which limits the number of fetched parameters,
becomes negative after `release_and_reset_all()` is called.
2. `module.ds_grads_remaining` becomes negative in `backward()` when the
module is called more than once in a single training step.
This PR adds two patches to fix these issues.
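For context, a minimal sketch of the DPO-style double forward that exposes the problem. This is plain PyTorch with illustrative names, not DeepSpeed internals; under ZeRO-3 the second (no_grad) forward re-fetches and releases partitioned parameters, which is where the two counters above went negative.

```python
import torch
import torch.nn as nn

# Illustrative model and optimizer; in the real setup this would be the
# DeepSpeed-wrapped policy model under ZeRO-3.
model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)

# First call: policy forward, participates in the autograd graph.
policy_logits = model(x)

# Second call in the SAME training step: reference forward under no_grad().
# With ZeRO-3, this extra call triggers another fetch/release cycle on the
# partitioned parameters, tripping the bookkeeping bugs described above.
with torch.no_grad():
    ref_logits = model(x)

# Toy loss combining both forwards, as DPO does.
loss = (policy_logits - ref_logits).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
```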
---------
Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>