refactor: pass a role string to OpenAI API (#7404)

* draft

* rm unused imports
Stefano Fiorucci 2024-03-22 09:36:56 +01:00 committed by GitHub
parent e779d43384
commit c789f905bc
5 changed files with 38 additions and 39 deletions
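
In short, both OpenAI generators previously carried their own private _convert_to_openai_format helper; this commit moves that logic onto ChatMessage as to_openai_format(), which passes the role explicitly as a string (self.role.value). A minimal sketch of the resulting flow, not part of the diff — the model name is an arbitrary example and the OpenAI client reads OPENAI_API_KEY from the environment:

from haystack.dataclasses import ChatMessage
from openai import OpenAI

messages = [
    ChatMessage.from_system("You are good assistant"),
    ChatMessage.from_user("I have a question"),
]
# each message now serializes itself; role arrives as a plain string such as "user"
openai_formatted_messages = [message.to_openai_format() for message in messages]

client = OpenAI()
chat_completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=openai_formatted_messages,
)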


@@ -228,8 +228,9 @@ class HuggingFaceTGIChatGenerator:
             raise RuntimeError("Please call warm_up() before running LLM inference.")
 
         # apply either model's chat template or the user-provided one
+        formatted_messages = [message.to_openai_format() for message in messages]
         prepared_prompt: str = self.tokenizer.apply_chat_template(
-            conversation=messages, chat_template=self.chat_template, tokenize=False
+            conversation=formatted_messages, chat_template=self.chat_template, tokenize=False
         )
         prompt_token_count: int = len(self.tokenizer.encode(prepared_prompt, add_special_tokens=False))
 
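
This also updates the TGI generator: transformers chat templates consume a list of {"role": ..., "content": ...} dicts, which is exactly the shape to_openai_format() produces, whereas raw ChatMessage objects were previously passed through. A standalone sketch of the templating step, borrowing the zephyr tokenizer used in the tests below:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
formatted_messages = [
    {"role": "system", "content": "You are good assistant"},
    {"role": "user", "content": "I have a question"},
]
# renders the dicts through the model's built-in Jinja chat template
prompt = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
print(prompt)  # "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"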


@@ -1,9 +1,8 @@
 import copy
-import dataclasses
 import json
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from openai import OpenAI, Stream  # type: ignore
+from openai import OpenAI, Stream
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
@@ -169,7 +168,7 @@ class OpenAIChatGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -204,20 +203,6 @@ class OpenAIChatGenerator:
 
         return {"replies": completions}
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-        :param messages: The list of ChatMessage.
-        :return: The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.


@@ -1,4 +1,3 @@
-import dataclasses
 from typing import Any, Callable, Dict, List, Optional, Union
 
 from openai import OpenAI, Stream
@@ -164,7 +163,7 @@ class OpenAIGenerator:
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = self._convert_to_openai_format(messages)
+        openai_formatted_messages = [message.to_openai_format() for message in messages]
 
         completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,
@@ -200,23 +199,6 @@ class OpenAIGenerator:
             "meta": [message.meta for message in completions],
         }
 
-    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-        """
-        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
-
-        :param messages:
-            The list of ChatMessage.
-        :returns:
-            The list of messages in the format expected by the OpenAI API.
-        """
-        openai_chat_message_format = {"role", "content", "name"}
-        openai_formatted_messages = []
-        for m in messages:
-            message_dict = dataclasses.asdict(m)
-            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
-            openai_formatted_messages.append(filtered_message)
-        return openai_formatted_messages
-
     def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         """
         Connects the streaming chunks into a single ChatMessage.


@@ -28,6 +28,22 @@ class ChatMessage:
     name: Optional[str]
     meta: Dict[str, Any] = field(default_factory=dict, hash=False)
 
+    def to_openai_format(self) -> Dict[str, Any]:
+        """
+        Convert the message to the format expected by OpenAI's Chat API.
+
+        See the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for details.
+
+        :returns: A dictionary with the following keys:
+            - `role`
+            - `content`
+            - `name` (optional)
+        """
+        msg = {"role": self.role.value, "content": self.content}
+        if self.name:
+            msg["name"] = self.name
+        return msg
+
     def is_from(self, role: ChatRole) -> bool:
         """
         Check if the message is from a specific role.
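
Note the conditional handling of name above: it is only included when set, so plain system/user messages serialize to two-key dicts while function messages carry all three. A quick reference sketch, mirroring the new tests below:

from haystack.dataclasses import ChatMessage

assert ChatMessage.from_user("I have a question").to_openai_format() == {
    "role": "user",
    "content": "I have a question",
}
assert ChatMessage.from_function("Function call", "function_name").to_openai_format() == {
    "role": "function",
    "content": "Function call",
    "name": "function_name",
}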


@@ -37,11 +37,23 @@ def test_from_function_with_empty_name():
     assert message.name == ""
 
 
+def test_to_openai_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert message.to_openai_format() == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert message.to_openai_format() == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert message.to_openai_format() == {"role": "function", "content": "Function call", "name": "function_name"}
+
+
 @pytest.mark.integration
 def test_apply_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(formatted_messages, tokenize=False)
     assert tokenized_messages == "<|system|>\nYou are good assistant</s>\n<|user|>\nI have a question</s>\n"
 
 
@@ -61,5 +73,8 @@ def test_apply_custom_chat_templating_on_chat_message():
     messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
     # could be any tokenizer, let's use the one we already likely have in cache
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
-    tokenized_messages = tokenizer.apply_chat_template(messages, chat_template=anthropic_template, tokenize=False)
+    formatted_messages = [m.to_openai_format() for m in messages]
+    tokenized_messages = tokenizer.apply_chat_template(
+        formatted_messages, chat_template=anthropic_template, tokenize=False
+    )
     assert tokenized_messages == "You are good assistant\nHuman: I have a question\nAssistant:"
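
anthropic_template itself is defined earlier in the test file and is not shown in this diff; a hypothetical Jinja template of the following shape would yield the asserted string, sketched here purely for illustration:

anthropic_template = (
    "{%- for message in messages %}"
    "{%- if message.role == 'system' %}{{ message.content }}"
    "{%- elif message.role == 'user' %}\nHuman: {{ message.content }}"
    "{%- endif %}"
    "{%- endfor %}"
    "\nAssistant:"
)
# tokenizer.apply_chat_template(formatted_messages, chat_template=anthropic_template, tokenize=False)
# -> "You are good assistant\nHuman: I have a question\nAssistant:"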