diff --git a/haystack/components/generators/chat/hugging_face_tgi.py b/haystack/components/generators/chat/hugging_face_tgi.py
index 95adbe792..ff6c119f4 100644
--- a/haystack/components/generators/chat/hugging_face_tgi.py
+++ b/haystack/components/generators/chat/hugging_face_tgi.py
@@ -228,8 +228,9 @@ class HuggingFaceTGIChatGenerator:
             raise RuntimeError("Please call warm_up() before running LLM inference.")
 
         # apply either model's chat template or the user-provided one
+        formatted_messages = [message.to_openai_format() for message in messages]
         prepared_prompt: str = self.tokenizer.apply_chat_template(
-            conversation=messages, chat_template=self.chat_template, tokenize=False
+            conversation=formatted_messages, chat_template=self.chat_template, tokenize=False
         )
         prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))
 
diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py
index e3ce9e1c0..d05ed1d8b 100644
--- a/haystack/components/generators/chat/openai.py
+++ b/haystack/components/generators/chat/openai.py
@@ -1,9 +1,8 @@
 import copy
-import dataclasses
 import json
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from openai import OpenAI, Stream  # type: ignore
+from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
@@ -169,7 +168,7 @@ class OpenAIChatGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -204,20 +203,6 @@ class OpenAIChatGenerator:
 
         return {"replies": completions}
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-        :param messages: The list of ChatMessage.
-        :return: The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
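Both generator changes above lean on the same one-liner: map each `ChatMessage` through `to_openai_format()` before handing the list to the OpenAI client or to `tokenizer.apply_chat_template`. A minimal self-contained sketch of that pattern, using a stand-in `Message` dataclass rather than Haystack's real `ChatMessage`:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class Message:
    """Stand-in for Haystack's ChatMessage; illustrative only."""

    role: str
    content: str
    name: Optional[str] = None

    def to_openai_format(self) -> Dict[str, Any]:
        # Emit "name" only when it is set, matching the OpenAI Chat API schema.
        msg: Dict[str, Any] = {"role": self.role, "content": self.content}
        if self.name:
            msg["name"] = self.name
        return msg


messages = [Message("system", "You are a helpful assistant"), Message("user", "Hi")]
# The shape the generators now build before calling the OpenAI client
# or a tokenizer's apply_chat_template:
openai_formatted = [m.to_openai_format() for m in messages]
assert openai_formatted == [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hi"},
]
```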
diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py
index f2f374e3d..966b552bd 100644
--- a/haystack/components/generators/openai.py
+++ b/haystack/components/generators/openai.py
@@ -1,4 +1,3 @@
-import dataclasses
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
@@ -164,7 +163,7 @@ class OpenAIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -200,23 +199,6 @@ class OpenAIGenerator:
             "meta": [message.meta for message in completions],
         }
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-
-        :param messages:
-            The list of ChatMessage.
-        :returns:
-            The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.
diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py
index e9a9b0a15..941c92e99 100644
--- a/haystack/dataclasses/chat_message.py
+++ b/haystack/dataclasses/chat_message.py
@@ -28,6 +28,22 @@ class ChatMessage:
     name: Optional[str]
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
 
+    def to_openai_format(self) -> Dict[str, Any]:
+        """
+        Convert the message to the format expected by OpenAI's Chat API.
+        See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
+
+        :returns: A dictionary with the following keys:
+            - `role`
+            - `content`
+            - `name` (optional)
+        """
+        msg = {"role": self.role.value, "content": self.content}
+        if self.name:
+            msg["name"] = self.name
+
+        return msg
+
     def is_from(self, role: ChatRole) -> bool:
         """
         Check if the message is from a specific role.
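With the conversion now living on the dataclass itself (see `chat_message.py` above), callers no longer need `dataclasses.asdict` plus key filtering. A quick usage sketch, assuming the import path implied by the diff's module location:

```python
from haystack.dataclasses.chat_message import ChatMessage

msg = ChatMessage.from_user("I have a question")
print(msg.to_openai_format())
# {'role': 'user', 'content': 'I have a question'}

# `name` is emitted only when truthy, so a function message created with
# an empty name serializes without a "name" key at all.
fn_msg = ChatMessage.from_function("result", "")
print(fn_msg.to_openai_format())
# {'role': 'function', 'content': 'result'}
```

The truthiness check also means `role` and `content` are always present while `name` is strictly optional, which is exactly the key filtering the deleted `_convert_to_openai_format` helpers performed.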
diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py
index 22e193660..1a8c6ad7b 100644
--- a/test/dataclasses/test_chat_message.py
+++ b/test/dataclasses/test_chat_message.py
@@ -37,11 +37,23 @@ def test_from_function_with_empty_name():
     assert message.name == ""
 
 
+def test_to_openai_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert message.to_openai_format() == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert message.to_openai_format() == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert message.to_openai_format() == {"role": "function", "content": "Function call", "name": "function_name"}
+
+
 @pytest.mark.integration
 def test_apply_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
     assert tokenized_messages == "<|system|>\nYou are good assistant\n<|user|>\nI have a question\n"
 
 
@@ -61,5 +73,8 @@ def test_apply_custom_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     # could be any tokenizer, let's use the one we already likely have in cache
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(
+        formatted_messages, chat_template=anthropic_template, tokenize=False
+    )
     assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
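The two integration tests above download a real tokenizer, but the underlying mechanic is plain Jinja templating over the role/content dicts that `to_openai_format()` produces. A dependency-light sketch (the template string below is reconstructed to match the expected output; the test's actual `anthropic_template` sits outside the hunk shown and may differ):

```python
from jinja2 import Template

# Illustrative Anthropic-style template, in the spirit of the test's
# `anthropic_template`; reconstructed, not copied from the test file.
anthropic_style = Template(
    "{% for m in messages %}"
    "{% if m.role == 'system' %}{{ m.content }}{% endif %}"
    "{% if m.role == 'user' %}\nHuman: {{ m.content }}{% endif %}"
    "{% endfor %}\nAssistant:"
)

# The same role/content dicts that to_openai_format() produces.
formatted = [
    {"role": "system", "content": "You are good assistant"},
    {"role": "user", "content": "I have a question"},
]
print(anthropic_style.render(messages=formatted))
# You are good assistant
# Human: I have a question
# Assistant:
```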