fix: Restore explicit .keys() calls for TensorDict compatibility (#42373)
* fix: Restore explicit .keys() calls for TensorDict compatibility
Fixes issue where TensorDict objects cause RuntimeError: generator raised StopIteration
when used with data collators and tokenization utilities.
Problem:
- TensorDict.__iter__() iterates over batch dimensions instead of dictionary keys
- PR #37283 removed explicit .keys() calls, breaking TensorDict compatibility
- Affected DataCollatorWithPadding, DataCollatorForLanguageModeling, and other collators
Solution:
- Restored explicit .keys() calls in 5 critical locations where dict-list conversion happens
- Added len() check to handle empty batch edge case
- Changes are backward compatible and generalize to all Mapping objects
Files modified:
- src/transformers/tokenization_utils_base.py: Fixed pad() method
- src/transformers/tokenization_mistral_common.py: Fixed pad() method
- src/transformers/feature_extraction_sequence_utils.py: Fixed pad() method
- src/transformers/models/mluke/tokenization_mluke.py: Fixed pad() method
- src/transformers/models/luke/tokenization_luke.py: Fixed pad() method
Testing:
- Added comprehensive test suite: tests/trainer/test_tensordict_compatibility.py
- 7 test cases covering basic padding, variable lengths, mixed inputs, additional fields
- Added @require_tensordict decorator and is_tensordict_available() in testing_utils.py
- All existing tests pass (54/54 data collator tests, 2/2 padding tests)
Impact:
- Zero performance regression for standard dict usage
- Restores functionality for TensorDict and other Mapping implementations
- Fully backward compatible
* style: Apply ruff formatting to fix CI checks
* refactor: Address reviewer feedback on TensorDict fix
- Move TensorDict tests from standalone file to test_data_collator.py
- Simplify comments from verbose explanation to short reference: 'Call .keys() explicitly to avoid issue #42370'
- Delete tests/trainer/test_tensordict_compatibility.py (tests now in test_data_collator.py)
Addresses feedback from @ligz08 in PR review
* docs: Update comments to be self-explanatory about TensorDict compatibility
Changed from 'avoid issue #42370' to 'for compatibility with TensorDict and other Mapping subclasses'
so users don't need to look up the issue on GitHub to understand why .keys() is needed.
Addresses maintainer feedback.
* Remove TensorDict tests and utilities as requested
- Removed TensorDictCompatibilityTest class from test_data_collator.py
- Removed is_tensordict_available() and require_tensordict() from testing_utils.py
- TensorDict is not a CI dependency, so these tests would be skipped anyway
- The .keys() fix for TensorDict compatibility remains in place
---------
Co-authored-by: Pankaj Baid <baidpankaj567@gmail.com>