XLA train step fixes #17973
Copy inputs to train and test step before modifying them, as this bre…
b924b407
Add XLA tests, fix our loss functions to be XLA-compatible
0414cedc
make fixup
167fd324
gante
commented
on 2022-07-01
Update loss computation test to expect vector of per-sample losses
e01286d3
Patch loss for TFLED
3e537933
Patch loss for TFAlbert
43ce3f58
Add a tf_legacy_loss config flag that enables old loss functions
4060777a
Stop using config.get() because it's not a dict
391b050d
Skip loss computation test for RAG because its loss is very strange a…
8035a272
make fixup
58e3db87
sgugger
approved these changes
on 2022-07-01
Add XLA-compatible RAG loss
3b9fe743
Fix dtype of loss mask for TFAlbert
db79798f
Fix test for XLNet too because it overrides the default one
9a6b7b58
make fixup
92a4e798
Fix config test
6021439b
No more depending on GPU NaN behaviour
a46da255
Add test, avoid potential zero division
64c0e77e
Fix test item assignment
d34a3b2f
Fix loss computation masking test
32078b24
make fixup
a19ee4fd
Fix dtype bugs
f17136c8
Rocketknight1
deleted the xla_train_step_fixes branch 3 years ago
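Several commits above concern making the loss XLA-compatible: returning a vector of per-sample losses, no longer relying on GPU NaN behaviour, and avoiding a potential zero division. As a rough illustration only, the idea can be sketched in NumPy. The PR itself targets TensorFlow, and this is not its code. The sketch assumes an ignore label of `-100`, and the name `masked_mean_loss` is hypothetical; neither appears in this thread. Ignored positions are zeroed out rather than dropped, so array shapes stay fixed, and the divisor is clamped so a fully masked sample yields 0 instead of NaN.

```python
import numpy as np

IGNORE_INDEX = -100  # assumed "ignore" label value; not stated in this thread


def masked_mean_loss(per_token_loss, labels):
    """Per-sample mean loss over non-ignored tokens, keeping shapes static.

    Hypothetical helper for illustration; not taken from the PR.
    """
    per_token_loss = np.asarray(per_token_loss, dtype=np.float32)
    mask = np.asarray(labels) != IGNORE_INDEX
    # Select with where() instead of multiplying by the mask, so a NaN at an
    # ignored position cannot leak into the sum (NaN * 0 is still NaN).
    masked = np.where(mask, per_token_loss, 0.0)
    total = masked.sum(axis=-1)
    # Clamp the token count to at least 1 to avoid dividing by zero when
    # every token in a sample is ignored.
    count = np.maximum(mask.sum(axis=-1), 1)
    return total / count


losses = masked_mean_loss(
    [[1.0, 3.0, float("nan")], [2.0, 2.0, 2.0]],
    [[5, 7, IGNORE_INDEX], [IGNORE_INDEX] * 3],
)
print(losses.shape)  # one loss value per sample
```

Selecting values with a boolean index (dropping ignored tokens) would produce a data-dependent shape, which is the kind of operation that tends to be awkward under XLA compilation; the fixed-shape form above avoids that.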