DeepSpeed
ea0d8114 - optimize grad_norm calculation in stage3.py (#4436)

Commit
1 year ago
optimize grad_norm calculation in stage3.py (#4436)

reduce the synchronization between the device and the host by removing .item() from the loops that calculate the total norm.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
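The idea in the commit message can be illustrated with a minimal PyTorch sketch. This is not the code from stage3.py: the function names `total_norm_with_sync` and `total_norm_deferred` are hypothetical, and the sketch only assumes that gradients arrive as a list of tensors. The first version calls `.item()` on every iteration, which on an accelerator makes the host wait for the device each time; the second keeps the partial norms on the device and combines them there.

```python
import torch


def total_norm_with_sync(grads, norm_type=2.0):
    """Baseline: .item() inside the loop copies every partial norm to the host."""
    total = 0.0
    for g in grads:
        # .item() returns a Python float, so the host must wait for the device here.
        total += torch.linalg.vector_norm(g.float(), ord=norm_type).item() ** norm_type
    return total ** (1.0 / norm_type)


def total_norm_deferred(grads, norm_type=2.0):
    """Sketch of the optimization: keep partial norms as tensors on the device."""
    partial = torch.stack(
        [torch.linalg.vector_norm(g.float(), ord=norm_type) for g in grads]
    )
    # The p-norm of the per-tensor p-norms equals the p-norm of all elements.
    # The result stays a 0-dim tensor; the caller decides when (or whether) to
    # read it on the host.
    return torch.linalg.vector_norm(partial, ord=norm_type)


if __name__ == "__main__":
    grads = [torch.tensor([3.0, 0.0]), torch.tensor([[0.0], [4.0]])]
    print(total_norm_with_sync(grads))
    print(total_norm_deferred(grads))
```

Both functions compute the same value. The difference is that the deferred version performs no host read inside the loop, so a caller can continue queuing device work and convert the result to a Python number at most once, if at all.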