Reducing memory usage: removing useless logits computation in generate() #31292
gante commented on 2024-06-18
gante approved these changes on 2024-06-20
Add .float() in all generation methods logit outputs
1748ff17
Switch float-casting of logits to training only for main models
3f4f4e8d
Add `num_logits_to_keep` in Llama and add it by default in generate
727c7e46
Apply style
222017d4
Add num_logits_to_keep as arg in prepare_input_for_generation
dc709c68
Add support for Mistral
d2f1566f
Revert models except llama and mistral
f2ef90cd
Fix default None value in _supports_num_logits_to_keep()
ce7b980c
Fix dimension of dummy input
d4201f42
Add exception for prophetnet in _supports_num_logits_to_keep()
b15b5dec
Update _supports_num_logits_to_keep() to use inspect.signature()
95e0807a
Add deprecation cycle + remove modification with pretraining_tp
12db0457
Apply style
b224e24c
Add most used models
f0e1034b
Apply style
9ac57db6
Make `num_logits_to_keep` an int in all cases to remove if-else clause
f7421b69
Add compile check for the warning
c8f91776
Fix torch versions
5e1589e1
style
7998b650
Add gemma2
8fa80181
Update warning version
b49fe767
Add comment about .float operations in generation utils
cf9378a4
Add tests in GenerationTesterMixin and ModelTesterMixin
66e3e9d8
Fix batch size for assisted decoding in tests
e4c5a71b
fix small issues in test
b68ee166
Refactor test
e8374252
fix slicing removing dim issue
26863ca3
Add nemotron support (should fix check-copy issue in CIs)
3c3eeaa9
Trigger new CIs
c4008655
Trigger new CIs
802eca83
Bump version
4d6fae65
Bump version in TODO
f12f172f
Trigger CIs
7b1a26cc
remove blank space
b11b048f
Trigger CIs
f03adfb1
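Several commits above refer to `_supports_num_logits_to_keep()` and its switch to `inspect.signature()`. As a rough illustration only, a signature-based check could look like the sketch below. It assumes the check simply looks for a `num_logits_to_keep` parameter on the model's `forward`; the toy classes and the helper name `supports_num_logits_to_keep` are hypothetical and are not the code from this PR.

```python
import inspect


class ToyModelWithArg:
    """Hypothetical model whose forward accepts num_logits_to_keep."""

    def forward(self, input_ids, num_logits_to_keep=0):
        return input_ids


class ToyModelWithoutArg:
    """Hypothetical model whose forward does not accept the argument."""

    def forward(self, input_ids):
        return input_ids


def supports_num_logits_to_keep(model) -> bool:
    # Illustrative helper (assumption): look at the forward signature
    # rather than keeping a hard-coded list of supported model classes.
    return "num_logits_to_keep" in inspect.signature(model.forward).parameters


print(supports_num_logits_to_keep(ToyModelWithArg()))     # True
print(supports_num_logits_to_keep(ToyModelWithoutArg()))  # False
```

A check of this shape lets a caller pass the argument only to models that declare it, so unsupported models keep their existing behavior.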
gante merged 22e6f145 into main 1 year ago
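The idea in the PR title, skipping logits that `generate()` never uses, can be sketched without any framework. The example below is a toy model of it under stated assumptions: hidden states are sliced to the last `num_logits_to_keep` positions before the vocabulary projection, and `0` means "keep all positions" because `[-0:]` is a full slice in Python. The function names and the pure-Python `lm_head` are illustrative, not the library's implementation.

```python
def lm_head(hidden_states, weight):
    """Project each hidden vector onto the vocabulary (one logits row per position)."""
    return [[sum(h * w for h, w in zip(vec, row)) for row in weight]
            for vec in hidden_states]


def compute_logits(hidden_states, weight, num_logits_to_keep=0):
    # -0 == 0, so the default keeps every position; a positive value keeps
    # only the trailing positions. Slicing (not indexing) preserves the
    # sequence dimension even when a single position is kept.
    kept = hidden_states[-num_logits_to_keep:]
    return lm_head(kept, weight)


# Toy sizes: 4 positions, hidden size 2, vocabulary size 3.
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

all_logits = compute_logits(hidden, weight)       # logits for all 4 positions
last_logits = compute_logits(hidden, weight, 1)   # logits for the last position only

print(len(all_logits), len(last_logits))          # 4 1
print(last_logits[0] == all_logits[-1])           # True
```

With a long prompt and a large vocabulary, the full projection produces a sequence-length by vocabulary-size array, while next-token sampling only needs the final row, which is where the memory saving comes from.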
Assignees
No one assigned