Z3: optimizations for grad norm calculation and gradient clipping (#5504)
This PR add the below functionality:
1. complete_grad_norm_calculation_for_cpu_offload: move total_norm to
CPU, as expected device in such case is CPU..
2. repalce get_global_norm() with torch.linalg.norm for better
performance.
3. unscale_and_clip_grads: replace clipping based on if statement to use
torch.clamp for better performance.
change (3) is taken from
https://github.com/microsoft/DeepSpeed/pull/5547 (which was closed)
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Liran Bachar <lbachar@habana.ai>