transformers
909ece88 - Fix slow Trainer path with 4D attention mask (#45852)

Fix slow Trainer path with 4D attention mask (#45852)

* Fix slow Trainer path with 4D attention mask

Fixes #32101

PreTrainedTokenizerBase.pad runs higher-rank per-sample inputs through to_py_obj and rebuilds them with torch.tensor on a deeply nested Python list, which is the dominant cost on the 4D-mask path. _pad does not operate on rank > 1 per-sample inputs anyway, so the round trip is wasted work.

Pop a multi-dim per-sample attention_mask out before the per-sample padding loop and stack it back with torch.stack / np.stack at the end.

End-to-end on the OP's repro (CPU, batch=32, grad_accum=32, seq=1024, 1 optimizer step), the 4D Trainer step drops from 225 s to 115 s, within 13% of the 2D path. Collator-only at the same shape: ~3.65 s per batch to ~11.8 ms.

* Update src/transformers/tokenization_utils_base.py

* fix lint

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
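A minimal sketch of the idea described above, not the actual patch: the real change lives in PreTrainedTokenizerBase.pad in src/transformers/tokenization_utils_base.py, and the names pad_with_4d_mask and pad_one_example below are illustrative stand-ins. It shows the technique of popping a rank > 1 per-sample attention_mask before the per-sample padding loop and stacking it back in one call at the end.

```python
import torch


def pad_one_example(example, max_length, pad_token_id=0):
    # Stand-in for the per-sample _pad step: pads rank-1 features only.
    n = max_length - len(example["input_ids"])
    return {"input_ids": example["input_ids"] + [pad_token_id] * n}


def pad_with_4d_mask(examples, max_length):
    # Pop the multi-dimensional attention_mask before the per-sample padding
    # loop so it never round-trips through nested Python lists / torch.tensor.
    masks = [ex["attention_mask"] for ex in examples]
    stripped = [
        {k: v for k, v in ex.items() if k != "attention_mask"} for ex in examples
    ]

    # Per-sample padding as before, only on the remaining rank-1 features.
    padded = [pad_one_example(ex, max_length) for ex in stripped]
    batch = {"input_ids": torch.tensor([ex["input_ids"] for ex in padded])}

    # Stack the per-sample masks back in one shot at the end
    # (np.stack would play the same role on the NumPy path).
    batch["attention_mask"] = torch.stack([torch.as_tensor(m) for m in masks])
    return batch


# Usage: each example carries a (1, seq, seq) mask alongside its token ids,
# so the stacked batch mask is 4D: (batch, 1, seq, seq).
seq = 4
examples = [
    {"input_ids": [1, 2, 3], "attention_mask": torch.tril(torch.ones(1, seq, seq))},
    {"input_ids": [4, 5], "attention_mask": torch.tril(torch.ones(1, seq, seq))},
]
batch = pad_with_4d_mask(examples, max_length=seq)
print(batch["input_ids"].shape, batch["attention_mask"].shape)
# torch.Size([2, 4]) torch.Size([2, 1, 4, 4])
```

The design point is that the expensive path (deeply nested list conversion per element) is skipped entirely for the high-rank mask, while the rank-1 features keep the existing per-sample padding behavior.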