transformers
Add TDT loss kernel
#46048
Open
Go
Login via GitHub
Home
Pricing
FAQ
Install
Login
via GitHub
Overview
Commits
82
Changes
View On
GitHub
Add TDT loss kernel
#46048
ebezzam
wants to merge 82 commits into
huggingface:main
from
ebezzam:tdt_loss_kernel
parakeet tdt intergration
fa7d6e0e
Add TDT decoder support for Parakeet ASR models
f2b49380
Add expected outputs for TDT, small fixes.
fa36657f
Separate CTC and TDT generate outputs.
05e2e346
Work with auto device, better init,
bb5ff331
Test timestamps and expose token duration.
9ec79b02
Add reproducer link.
33f128ec
fix: align TDT training and decoding with NeMo implementation
760b4b61
revert: restore lasr generated files to original state
b33002fc
warn: torchaudio rnnt_loss does not train duration head
48b39dd1
Relax timestamp test, and test nits.
e9f23ab6
feat: TDT training
e2b97aa1
chore: for cuda detection and run without patching
6b9fc731
Equivalent timestamp processing as Nemo, and various nits/cleanup.
6c879bc0
Merge branch 'parakeet-tdt' of github.com:lmaksym/transformers into p…
149e17f4
Simplify durations config.
36bfa639
Update training examples.
2df0ccca
chore: enable parralelism
388c6d36
chore: performance optimization
08b2b558
fix: formatting
0c4e05a8
Doc and testing nits
1ddd8049
Use active mask from current step, and nits.
f5126703
Better pre-allocate.
07d8e35e
TDT has separate pad token and blank token.
fab050a3
Merge branch 'main' into parakeet-tdt
c438565c
Regenerate lasr.
86d980c1
Merge branch 'parakeet-tdt' of github.com:lmaksym/transformers into p…
895c4a0a
Style checks and nits
ab21380b
Nits, put back ctc loss test
d0141d5f
More standard model output.
f7529d41
Style
77b95d73
Remove compute_loss flag and allow monkey patching to tdt loss
94eae66f
Update src/transformers/models/parakeet/modular_parakeet.py
f7d40675
Address various comments.
f75c17b6
More compatible with Transformers forward/generate approach
5a49b651
compile option for generation and decoder cache
881233fd
Cleaner, better conventions.
b41a8ee6
Merge branch 'main' into parakeet-tdt
897753a0
Update with main.
6c914dbe
doc nits
756cee1e
Imitate whisper for encoder outputs as input
f30c5364
Address tests and nits.
fa95fc8e
Inherit from GenerateMixIn for get_compiled_call
5df7f289
Comment nit
cd706d48
forward cleanup
a47ed8a5
generate cleanup + separate generation file
13b68cec
generate: add _supported_generation_modes
72c1ad00
automatic init of the loss
8e23b3df
modular cleanups
1cc39fd8
use is_encoder_decoder
531f297e
timestamp processing fully from tokens + durations
2c0f23af
convertion script update
cef6639e
test update
fd3cf9b2
make
e63a5bf1
Merge branch 'main' into parakeet-tdt
f9d1a4fc
test update
43ee7cd7
test update
c2a0f781
ensure correct loss computation
1fd7ed78
kernel loss
7cc9d2e7
test loss integration
e753eab1
push to hub pr
ed3fa4dc
integration tests to rely fully on transcripts
ab66b239
udpate fixtures
a5ba0c61
we don't need to monkey patch with numba anymore!
48279a67
fix pipeline usage
1d7680d4
nit
59ddcedb
fix usage
31490d19
Pass through tests and examples: improve kernel fallback, update with…
d8eb1b6f
Update checkpoint
1f1b912d
Merge branch 'main' into parakeet-tdt
9ab08d1e
Add TDT to mapping after merge.
fd9f8b1b
Fix lasr generate test.
136f6768
Output attention mask if labels provided for computing loss.
833d2890
Apply suggestion from @ArthurZucker
a1c62a1f
Improve ParakeetTDTDecoderCache definition and usage.
86835704
Remove tuple parsing.
1d4b0f43
processor refactor
a418ecae
Merge branch 'parakeet-tdt' of github.com:lmaksym/transformers into p…
5d0c6318
Update conversion.
5c603c17
ebezzam
marked this pull request as draft
38 days ago
Merge branch 'main' into tdt_loss_kernel
09ba99c4
Modular after merge.
e743b2d2
ebezzam
commented on 2026-05-19
Don't allow all kernels.
8d09cb6a
ebezzam
commented on 2026-05-19
Login to write a write a comment.
Login via GitHub
Reviewers
No reviews
Assignees
No one assigned
Labels
None yet
Milestone
No milestone
Login to write a write a comment.
Login via GitHub