transformers

Add chat support to text generation pipeline #28945
Merged

Rocketknight1 · 1 year ago (edited)

This PR modifies the text generation pipeline to support chats. It does this by inspecting the inputs - if they look like strings, it uses the original causal LM pipeline, and if they look like lists of message dicts, it applies a chat template instead before proceeding with generation.

Most changes are in the preprocessing/postprocessing - the actual generation itself is largely unchanged.
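
For illustration, a rough usage sketch of what this enables (the model checkpoint here is just an example, any chat model whose tokenizer ships a chat template should behave the same way):

    from transformers import pipeline

    pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta")  # example checkpoint

    # A plain string goes through the original causal LM path.
    print(pipe("The capital of France is", max_new_tokens=10)[0]["generated_text"])

    # A list of message dicts is detected as a chat and run through the chat template first.
    chat = [
        {"role": "system", "content": "This is a system message."},
        {"role": "user", "content": "This is a test"},
    ]
    print(pipe(chat, max_new_tokens=50)[0]["generated_text"])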

TODO:

  • Expand tests to cover other edge cases
  • Confirm the return format we want for this - just the model response, or the entire chat?
  • Add KV cache support, as this is important for performant multi-turn chat
  • Deprecate ConversationalPipeline and update the chat template docs to refer to this instead?

cc @ArthurZucker @gante @LysandreJik

Rocketknight1 Add chat support to text generation pipeline
e7c4172b
HuggingFaceDocBuilderDev · 1 year ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

julien-c commented on 2024-02-09

looks neat from a (very) superficial glance

I think this will be quite useful!

julien-c · 1 year ago

(and yes we should remove the old ConversationalPipeline sooner rather than later given it already doesn't work anymore due to conversational pipeline-type being removed from the Hub, IIUC)

Rocketknight1 Better handling of single elements
a2e190ab
Rocketknight1 Deprecate ConversationalPipeline
a5e1ccc8
Rocketknight1 stash commit
4444a1ee
Rocketknight1 Add missing add_special_tokens kwarg
67726152
Rocketknight1 Update chat templating docs to refer to TextGenerationPipeline instea…
de3d88ab
Rocketknight1 · 1 year ago · πŸŽ‰ 1

@julien-c Done! This PR now adds a DeprecationWarning to ConversationalPipeline. I also updated the chat template docs for the new pipeline.

julien-c · 1 year ago

very nice!

Rocketknight1 Add ✨TF✨ tests
7eb468db
Rocketknight1 @require_tf
6fae42d1
gante approved these changes on 2024-02-14

Nice! Thanks for adding this 🐈

src/transformers/pipelines/conversational.py

     def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in a future release.",

gante · 1 year ago

I'd add a specific version in which the class will be removed (v4.40, i.e. two minor releases after the next minor release?)

Rocketknight1 · 1 year ago

Done! I picked 4.42 to give people more time.

src/transformers/pipelines/text_generation.py

+    to this format because the rest of the pipeline code tends to assume that lists of messages are
+    actually a batch of samples rather than messages in the same conversation."""
+
+    def __init__(self, messages):

gante · 1 year ago
Suggested change
-    def __init__(self, messages):
+    def __init__(self, messages: dict):

nit: helps understanding that a dict is expected

Rocketknight1 · 1 year ago

Done!

src/transformers/pipelines/text_generation.py

+        if isinstance(text_inputs[0], dict):
+            return super().__call__(Chat(text_inputs), **kwargs)
+        else:
+            chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈

gante · 1 year ago

best comment πŸ˜‚

Rocketknight1 Add type hint
31640358
Rocketknight1 Add specific deprecation version
01fc1a67
Rocketknight1 · 1 year ago

One question for people, maybe @gante: Are you okay with the return format I'm using? Right now, if you pass a chat like this:

[ 
    {"role": "system", "content": "This is a system message."},
    {"role": "user", "content": "This is a test"},
]

You get a response that's the same chat, continued:

[
    {"role": "system", "content": "This is a system message."},
    {"role": "user", "content": "This is a test"},
    {"role": "assistant", "content": "This is a reply"},
]

I think this is the right thing to do, because it matches the behaviour of the existing text-generation pipeline (it returns the prompt at the start of the generated string). Let me know if you have a different opinion, though!
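
With that format, pulling out just the new assistant turn would look something like this (a sketch reusing the pipe and chat objects from the earlier example, assuming the list-of-dicts return shown above):

    output = pipe(chat, max_new_tokens=50)
    full_chat = output[0]["generated_text"]  # the input chat plus the generated assistant message
    reply = full_chat[-1]["content"]         # just the assistant's response text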

gante · 1 year ago

It looks good to me!

Rocketknight1 · 1 year ago

Cool!

Rocketknight1 · 1 year ago · πŸ‘ 1

In that case, I think we're ready for final review (cc @amyeroberts) - I'm leaving the KV cache to another PR.

Rocketknight1 requested a review from amyeroberts 1 year ago
Rocketknight1 · 1 year ago

cc @LysandreJik @julien-c as well if there's anything else you want me to add before we merge this!

Rocketknight1 Remove unnecessary do_sample
1b3f53f2
Rocketknight1 Remove todo - the discrepancy has been resolved
bbd8cfc8
amyeroberts approved these changes on 2024-02-15

Beautiful - thanks for adding this support!

src/transformers/tokenization_utils_base.py

             padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
         if tokenize:
-            return self.encode(
-                rendered,
-                add_special_tokens=False,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                return_tensors=return_tensors,
-                **tokenizer_kwargs,
-            )
+            if return_dict:
+                return self(
+                    rendered,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    return_tensors=return_tensors,
+                    add_special_tokens=False,
+                    **tokenizer_kwargs,
+                )
+            else:
+                return self.encode(
+                    rendered,
+                    add_special_tokens=False,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    return_tensors=return_tensors,
+                    **tokenizer_kwargs,
+                )

amyeroberts · 1 year ago

nit - let's put the arguments in the same order to make it easier to check

Suggested change
-                return self(
-                    rendered,
-                    padding=padding,
-                    truncation=truncation,
-                    max_length=max_length,
-                    return_tensors=return_tensors,
-                    add_special_tokens=False,
-                    **tokenizer_kwargs,
-                )
-            else:
-                return self.encode(
-                    rendered,
-                    add_special_tokens=False,
-                    padding=padding,
-                    truncation=truncation,
-                    max_length=max_length,
-                    return_tensors=return_tensors,
-                    **tokenizer_kwargs,
-                )
+                return self(
+                    rendered,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    add_special_tokens=False,
+                    return_tensors=return_tensors,
+                    **tokenizer_kwargs,
+                )
+            else:
+                return self.encode(
+                    rendered,
+                    padding=padding,
+                    truncation=truncation,
+                    max_length=max_length,
+                    add_special_tokens=False,
+                    return_tensors=return_tensors,
+                    **tokenizer_kwargs,
+                )
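
As an aside, a minimal sketch of how the new return_dict path might be called directly (the checkpoint name is just an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # example chat model
    chat = [{"role": "user", "content": "Hi there!"}]

    # return_dict=True yields a BatchEncoding (input_ids + attention_mask) rather than a bare
    # list of token ids, which is what the pipeline needs to pass straight into generate().
    encoded = tok.apply_chat_template(chat, add_generation_prompt=True, return_dict=True, return_tensors="pt")
    print(encoded["input_ids"].shape)
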
src/transformers/pipelines/text_generation.py

         **generate_kwargs,
     ):
-        inputs = self.tokenizer(
-            prefix + prompt_text,
-            return_tensors=self.framework,
-            truncation=truncation,
-            padding=padding,
-            max_length=max_length,
-            add_special_tokens=add_special_tokens,
-        )
+        if isinstance(prompt_text, Chat):
+            inputs = self.tokenizer.apply_chat_template(
+                prompt_text.messages,
+                padding=padding,
+                add_generation_prompt=True,
+                return_tensors=self.framework,
+                max_length=max_length,
+                truncation=truncation,
+                return_dict=True,
+            )
+        else:
+            inputs = self.tokenizer(
+                prefix + prompt_text,
+                return_tensors=self.framework,
+                truncation=truncation,
+                padding=padding,
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+            )

amyeroberts · 1 year ago

nit - same here (+ my own personal preference for return_tensors being the last option passed)

Suggested change
-            inputs = self.tokenizer.apply_chat_template(
-                prompt_text.messages,
-                padding=padding,
-                add_generation_prompt=True,
-                return_tensors=self.framework,
-                max_length=max_length,
-                truncation=truncation,
-                return_dict=True,
-            )
-        else:
-            inputs = self.tokenizer(
-                prefix + prompt_text,
-                return_tensors=self.framework,
-                truncation=truncation,
-                padding=padding,
-                max_length=max_length,
-                add_special_tokens=add_special_tokens,
-            )
+            inputs = self.tokenizer.apply_chat_template(
+                prompt_text.messages,
+                truncation=truncation,
+                padding=padding,
+                max_length=max_length,
+                add_generation_prompt=True,
+                return_dict=True,
+                return_tensors=self.framework,
+            )
+        else:
+            inputs = self.tokenizer(
+                prefix + prompt_text,
+                truncation=truncation,
+                padding=padding,
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
+                return_tensors=self.framework,
+            )

src/transformers/pipelines/text_generation.py

               - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
                 ids of the generated text.
         """
-        return super().__call__(text_inputs, **kwargs)
+        if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):

amyeroberts · 1 year ago

Just to make sure - is it not possible for someone to pass this to the pipeline:

# Pass a list-of-list-of-strings
generator([["this is a dog"], ["this is a code example"], ["banana for scale"]])
Rocketknight1 · 1 year ago · πŸ‘ 1

I tried that on main - it just results in a TypeError: can only concatenate str (not "list") to str. The existing pipeline will only accept either a single string or a non-nested list/tuple of strings, so I don't think this check makes a mistake for any valid inputs!
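
As a rough sketch, the dispatch being discussed boils down to a check along these lines (not the exact pipeline code):

    def looks_like_chat(text_inputs):
        # A single chat is a list/tuple of message dicts; a batch of chats is a list of such lists.
        return isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict))

    looks_like_chat("The capital of France is")                # False - plain text path
    looks_like_chat([{"role": "user", "content": "Hi!"}])      # True  - single chat
    looks_like_chat([["this is a dog"], ["banana for scale"]]) # True, but such input already errors on main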

Rocketknight1 Update src/transformers/tokenization_utils_base.py
ecea9b51
Rocketknight1 Update src/transformers/pipelines/text_generation.py
f9857554
Rocketknight1 merged 2f1003be into main 1 year ago
Rocketknight1 deleted the support_chat_in_text_gen_pipeline branch 1 year ago
