transformers
1bca9bac - fix: Restore explicit .keys() calls for TensorDict compatibility (#42373)

Commit
21 days ago
fix: Restore explicit .keys() calls for TensorDict compatibility (#42373) * fix: Restore explicit .keys() calls for TensorDict compatibility Fixes issue where TensorDict objects cause RuntimeError: generator raised StopIteration when used with data collators and tokenization utilities. Problem: - TensorDict.__iter__() iterates over batch dimensions instead of dictionary keys - PR #37283 removed explicit .keys() calls, breaking TensorDict compatibility - Affected DataCollatorWithPadding, DataCollatorForLanguageModeling, and other collators Solution: - Restored explicit .keys() calls in 5 critical locations where dict-list conversion happens - Added len() check to handle empty batch edge case - Changes are backward compatible and generalize to all Mapping objects Files modified: - src/transformers/tokenization_utils_base.py: Fixed pad() method - src/transformers/tokenization_mistral_common.py: Fixed pad() method - src/transformers/feature_extraction_sequence_utils.py: Fixed pad() method - src/transformers/models/mluke/tokenization_mluke.py: Fixed pad() method - src/transformers/models/luke/tokenization_luke.py: Fixed pad() method Testing: - Added comprehensive test suite: tests/trainer/test_tensordict_compatibility.py - 7 test cases covering basic padding, variable lengths, mixed inputs, additional fields - Added @require_tensordict decorator and is_tensordict_available() in testing_utils.py - All existing tests pass (54/54 data collator tests, 2/2 padding tests) Impact: - Zero performance regression for standard dict usage - Restores functionality for TensorDict and other Mapping implementations - Fully backward compatible * style: Apply ruff formatting to fix CI checks * refactor: Address reviewer feedback on TensorDict fix - Move TensorDict tests from standalone file to test_data_collator.py - Simplify comments from verbose explanation to short reference: 'Call .keys() explicitly to avoid issue #42370' - Delete tests/trainer/test_tensordict_compatibility.py (tests now in test_data_collator.py) Addresses feedback from @ligz08 in PR review * docs: Update comments to be self-explanatory about TensorDict compatibility Changed from 'avoid issue #42370' to 'for compatibility with TensorDict and other Mapping subclasses' so users don't need to look up the issue on GitHub to understand why .keys() is needed. Addresses maintainer feedback. * Remove TensorDict tests and utilities as requested - Removed TensorDictCompatibilityTest class from test_data_collator.py - Removed is_tensordict_available() and require_tensordict() from testing_utils.py - TensorDict is not a CI dependency, so these tests would be skipped anyway - The .keys() fix for TensorDict compatibility remains in place --------- Co-authored-by: Pankaj Baid <baidpankaj567@gmail.com>
Author
Parents
Loading