optimize grad_norm calculation in stage3.py (#4436)
Reduce synchronization between the device and the host by removing
.item() calls from the loops that compute the total gradient norm.
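
A minimal sketch of the idea, not the actual stage3.py code: the function
names and the sample gradients below are hypothetical, and the gradient list
is assumed to be non-empty. Calling .item() inside the loop copies one scalar
to the host per gradient, whereas accumulating into a tensor keeps the work
on the device and leaves any host read to the caller.

```python
import torch


def total_norm_with_sync(grads):
    # Hypothetical "before" version: .item() on every iteration forces a
    # device-to-host copy (and a synchronization) once per gradient.
    total = 0.0
    for g in grads:
        total += g.double().norm(2).item() ** 2
    return total ** 0.5


def total_norm_deferred(grads):
    # Hypothetical "after" version: accumulate the squared norms in a
    # 0-dim tensor on the gradients' device; no .item() inside the loop.
    total = torch.zeros((), dtype=torch.double, device=grads[0].device)
    for g in grads:
        total += g.double().norm(2) ** 2
    return total.sqrt()  # still a tensor; the caller decides when to read it


# Illustrative data only: two small CPU gradients with norms 3 and 4.
grads = [torch.tensor([3.0, 0.0]), torch.tensor([0.0, 4.0])]
```

Both functions compute the same 2-norm over all gradients; the second one
defers the single host read until the result is actually needed.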
---------
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>