mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-26 22:48:29 +00:00)

refactor: pass a role string to OpenAI API (#7404)

* draft
* rm unused imports

parent e779d43384
commit c789f905bc
@@ -228,8 +228,9 @@ class HuggingFaceTGIChatGenerator:
            raise RuntimeError("Please call warm_up() before running LLM inference.")

        # apply either model's chat template or the user-provided one
+        formatted_messages = [message.to_openai_format() for message in messages]
        prepared_prompt: str = self.tokenizer.apply_chat_template(
-            conversation=messages, chat_template=self.chat_template, tokenize=False
+            conversation=formatted_messages, chat_template=self.chat_template, tokenize=False
        )
        prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))

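For reference, a minimal sketch of what this change does at the call site, assuming Haystack's ChatMessage import path and the same tokenizer checkpoint the tests below use:

from transformers import AutoTokenizer
from haystack.dataclasses import ChatMessage  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]

# apply_chat_template expects plain role/content dicts, not ChatMessage dataclasses,
# so the messages are converted first; this is exactly what the hunk above adds
formatted_messages = [message.to_openai_format() for message in messages]
prepared_prompt = tokenizer.apply_chat_template(conversation=formatted_messages, tokenize=False)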
@@ -1,9 +1,8 @@
 import copy
-import dataclasses
 import json
 from typing import Any, Callable, Dict, List, Optional, Union

-from openai import OpenAI, Stream  # type: ignore
+from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
@@ -169,7 +168,7 @@ class OpenAIChatGenerator:
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]

        chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
            model=self.model,
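The same inline conversion, sketched end to end against the OpenAI v1 client; the model name and messages are illustrative, and the client reads OPENAI_API_KEY from the environment:

from haystack.dataclasses import ChatMessage  # assumed import path
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment
messages = [ChatMessage.from_user("I have a question")]

# each element becomes a plain dict such as {"role": "user", "content": "..."}
openai_formatted_messages = [message.to_openai_format() for message in messages]
chat_completion = client.chat.completions.create(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=openai_formatted_messages,
)
print(chat_completion.choices[0].message.content)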
@@ -204,20 +203,6 @@ class OpenAIChatGenerator:

        return {"replies": completions}

-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-        :param messages: The list of ChatMessage.
-        :return: The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
    def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
        """
        Connects the streaming chunks into a single ChatMessage.
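Why the helper could be deleted: it built the payload with dataclasses.asdict(), which leaves Enum fields as Enum members rather than their string values, whereas to_openai_format() uses self.role.value, the role string the commit title refers to. A self-contained repro with stand-in types (Haystack's actual ChatRole and ChatMessage definitions may differ):

import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class ChatRole(Enum):
    # stand-in: a plain Enum for illustration
    USER = "user"

@dataclass
class ChatMessage:
    # stand-in carrying the same three payload fields the old helper filtered on
    content: str
    role: ChatRole
    name: Optional[str] = None

m = ChatMessage(content="hi", role=ChatRole.USER)

# the old path: asdict() leaves the Enum member in place
print(dataclasses.asdict(m)["role"])                  # ChatRole.USER, not "user"

# the new path: role.value yields the plain role string the API expects
print({"role": m.role.value, "content": m.content})   # {'role': 'user', 'content': 'hi'}

The identical change lands in OpenAIGenerator below, which is why the diff repeats.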
@@ -1,4 +1,3 @@
-import dataclasses
 from typing import Any, Callable, Dict, List, Optional, Union

 from openai import OpenAI, Stream
@@ -164,7 +163,7 @@ class OpenAIGenerator:
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]

        completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
            model=self.model,
@@ -200,23 +199,6 @@ class OpenAIGenerator:
            "meta": [message.meta for message in completions],
        }

-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-
-        :param messages:
-            The list of ChatMessage.
-        :returns:
-            The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
    def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
        """
        Connects the streaming chunks into a single ChatMessage.
@@ -28,6 +28,22 @@ class ChatMessage:
    name: Optional[str]
    meta: Dict[str, Any] = field(default_factory=dict, hash=False)

+    def to_openai_format(self) -> Dict[str, Any]:
+        """
+        Convert the message to the format expected by OpenAI's Chat API.
+        See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
+
+        :returns: A dictionary with the following keys:
+            - `role`
+            - `content`
+            - `name` (optional)
+        """
+        msg = {"role": self.role.value, "content": self.content}
+        if self.name:
+            msg["name"] = self.name
+
+        return msg
+
    def is_from(self, role: ChatRole) -> bool:
        """
        Check if the message is from a specific role.
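A quick check of the new method in isolation, mirroring the unit test below; the json.dumps round-trip shows the payload is now plain, JSON-serializable data:

import json
from haystack.dataclasses import ChatMessage  # assumed import path

msg = ChatMessage.from_function("Function call", "function_name")
print(json.dumps(msg.to_openai_format()))
# {"role": "function", "content": "Function call", "name": "function_name"}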
@@ -37,11 +37,23 @@ def test_from_function_with_empty_name():
    assert message.name == ""


+def test_to_openai_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert message.to_openai_format() == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert message.to_openai_format() == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert message.to_openai_format() == {"role": "function", "content": "Function call", "name": "function_name"}
+
+
 @pytest.mark.integration
 def test_apply_chat_templating_on_chat_message():
    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
    assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"
@@ -61,5 +73,8 @@ def test_apply_custom_chat_templating_on_chat_message():
    messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
    # could be any tokenizer, let's use the one we already likely have in cache
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(
+        formatted_messages, chat_template=anthropic_template, tokenize=False
+    )
    assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
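anthropic_template is a fixture defined elsewhere in the test module and is not shown in this hunk. For orientation only, a hypothetical Jinja2 template along these lines would render the asserted string; this is a reconstruction, not the repository's fixture:

# hypothetical template, written only to reproduce the asserted output above
anthropic_template = (
    "{%- for message in messages %}"
    "{%- if message.role == 'system' %}{{ message.content }}"
    "{%- elif message.role == 'user' %}\nHuman: {{ message.content }}"
    "{%- endif %}"
    "{%- endfor %}"
    "\nAssistant:"
)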