Gradient accumulation fix in cross entropy loss #21386
Introduce peekable iterator to count number of valid tokens in the gl…
385fd56d
Scale loss by number of valid tokens in global batch in case of cross…
9b7aa6ff
Merge branch 'master' into bugfix/20350_grad_acc_fix
7f5f88cd
Ensure iterator is not None while passing to tee function
95d467d4
Merge branch 'bugfix/20350_grad_acc_fix' of https://github.com/Sohaib…
fb7dbc87
Merge branch 'master' into bugfix/20350_grad_acc_fix
01fcf620
Merge branch 'master' into bugfix/20350_grad_acc_fix
06216207
Assignees
No one assigned
Login to write a write a comment.
Login via GitHub