feat: Add GPTChatGenerator to Haystack 2.x (#6212)

* Add GPTChatGenerator

* Apply lessons from previous PR

* PR review - Stefano
Vladimir Blagojevic authored on 2023-11-09 10:45:41 +01:00, committed by GitHub
parent 08e211f9d6
commit cd429a73cd
6 changed files with 632 additions and 3 deletions

@@ -1,3 +1,4 @@
 from haystack.preview.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
+from haystack.preview.components.generators.chat.openai import GPTChatGenerator
 
-__all__ = ["HuggingFaceTGIChatGenerator"]
+__all__ = ["HuggingFaceTGIChatGenerator", "GPTChatGenerator"]
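With this change both chat generators can be imported from the chat subpackage; a minimal, purely illustrative sketch:

```python
# Both components are now exported via __all__ (illustration only).
from haystack.preview.components.generators.chat import (
    GPTChatGenerator,
    HuggingFaceTGIChatGenerator,
)
```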

@@ -0,0 +1,286 @@
import dataclasses
import logging
import os
from typing import Optional, List, Callable, Dict, Any
import openai
from openai.openai_object import OpenAIObject
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk, ChatMessage
logger = logging.getLogger(__name__)
API_BASE_URL = "https://api.openai.com/v1"
@component
class GPTChatGenerator:
"""
Enables text generation using OpenAI's large language models (LLMs). It supports the gpt-4 and gpt-3.5-turbo
families of models accessed through the chat completions API endpoint.
Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
directly to this component via the `**generation_kwargs` parameter in `__init__` or the `generation_kwargs`
parameter of the `run` method.
For more details on the parameters supported by the OpenAI API, refer to the OpenAI
[documentation](https://platform.openai.com/docs/api-reference/chat).
```python
from haystack.preview.components.generators.chat import GPTChatGenerator
from haystack.preview.dataclasses import ChatMessage
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
client = GPTChatGenerator()
response = client.run(messages)
print(response)
>>{'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence
>>that focuses on enabling computers to understand, interpret, and generate human language in a way that is
>>meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
>>metadata={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop',
>>'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
```
Key Features and Compatibility:
- **Primary Compatibility**: Designed to work seamlessly with the OpenAI API Chat Completion endpoint
and gpt-4 and gpt-3.5-turbo family of models.
- **Streaming Support**: Supports streaming responses from the OpenAI API Chat Completion endpoint.
- **Customizability**: Supports all parameters supported by the OpenAI API Chat Completion endpoint.
Input and Output Format:
- **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output,
ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: str = API_BASE_URL,
**generation_kwargs,
):
"""
Creates an instance of GPTChatGenerator. Unless specified otherwise in `model_name`, this uses OpenAI's
gpt-3.5-turbo model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
more details.
Some of the supported parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
- `stop`: One or more sequences after which the LLM should stop generating tokens.
- `presence_penalty`: The penalty to apply if a token is already present in the text so far. Higher values
make the model less likely to repeat the same token in the text.
- `frequency_penalty`: The penalty to apply if a token has already been generated, scaled by how often it
appears in the text. Higher values make the model less likely to repeat the same token.
- `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are token IDs, and the
values are the bias to add to each token.
"""
# if the user does not provide the API key, check if it is already set on the openai module
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"GPTChatGenerator expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key
self.model_name = model_name
self.generation_kwargs = generation_kwargs
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
openai.api_base = api_base_url
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.generation_kwargs,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPTChatGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)
@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invokes text generation based on the provided messages and generation parameters.
:param messages: A list of ChatMessage instances representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the __init__ method.
For more details on the parameters supported by the OpenAI API, refer to the
OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
:return: A dictionary with a `replies` key containing the generated responses as ChatMessage instances.
"""
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = self._convert_to_openai_format(messages)
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=openai_formatted_messages,
stream=self.streaming_callback is not None,
**generation_kwargs,
)
completions: List[ChatMessage]
if self.streaming_callback:
num_responses = generation_kwargs.pop("n", 1)
if num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")
chunks: List[StreamingChunk] = []
chunk = None
for chunk in completion:
if chunk.choices:
chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0])
chunks.append(chunk_delta)
self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
else:
completions = [self._build_message(completion, choice) for choice in completion.choices]
# before returning, do post-processing of the completions
for message in completions:
self._check_finish_reason(message)
return {"replies": completions}
def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
"""
Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
:param messages: The list of ChatMessage.
:return: The list of messages in the format expected by the OpenAI API.
"""
openai_chat_message_format = {"role", "content", "name"}
openai_formatted_messages = []
for m in messages:
message_dict = dataclasses.asdict(m)
filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
openai_formatted_messages.append(filtered_message)
return openai_formatted_messages
def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
:param chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all chunks returned by the OpenAI API.
"""
complete_response = ChatMessage.from_assistant("".join([c.content for c in chunks]))
complete_response.metadata.update(
{
"model": chunk.model,
"index": 0,
"finish_reason": chunk.choices[0].finish_reason,
"usage": {}, # we don't have usage data for streaming responses
}
)
return complete_response
def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage:
"""
Converts the non-streaming response from the OpenAI API to a ChatMessage.
:param completion: The completion returned by the OpenAI API.
:param choice: The choice returned by the OpenAI API.
:return: The ChatMessage.
"""
message: OpenAIObject = choice.message
content = message.function_call if choice.finish_reason == "function_call" else message.content
chat_message = ChatMessage.from_assistant(content)
chat_message.metadata.update(
{
"model": completion.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
"usage": dict(completion.usage.items()),
}
)
return chat_message
def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk:
"""
Converts the streaming response chunk from the OpenAI API to a StreamingChunk.
:param chunk: The chunk returned by the OpenAI API.
:param choice: The choice returned by the OpenAI API.
:return: The StreamingChunk.
"""
has_content = bool(hasattr(choice.delta, "content") and choice.delta.content)
if has_content:
content = choice.delta.content
elif hasattr(choice.delta, "function_call"):
content = choice.delta.function_call
else:
content = ""
chunk_message = StreamingChunk(content)
chunk_message.metadata.update(
{"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}
)
return chunk_message
def _check_finish_reason(self, message: ChatMessage) -> None:
"""
Check the `finish_reason` returned with the OpenAI completions.
If the `finish_reason` is `length` or `content_filter`, log a warning.
:param message: The message returned by the LLM.
"""
if message.metadata["finish_reason"] == "length":
logger.warning(
"The completion for index %s has been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions.",
message.metadata["index"],
)
if message.metadata["finish_reason"] == "content_filter":
logger.warning(
"The completion for index %s has been truncated due to the content filter.", message.metadata["index"]
)
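A minimal streaming sketch for the component above, assuming `OPENAI_API_KEY` is exported; the callback name and prompt are illustrative:

```python
from haystack.preview.components.generators.chat import GPTChatGenerator
from haystack.preview.dataclasses import ChatMessage, StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # Called once per streamed delta; chunk.content holds the text fragment.
    print(chunk.content, end="", flush=True)


# temperature is forwarded to openai.ChatCompletion.create via **generation_kwargs
client = GPTChatGenerator(streaming_callback=print_chunk, temperature=0.7)
result = client.run([ChatMessage.from_user("Summarize NLP in one sentence.")])
print()
print(result["replies"][0].metadata)  # usage stays {} for streamed responses
```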

@@ -38,14 +38,15 @@ class ChatMessage:
         return self.role == role

     @classmethod
-    def from_assistant(cls, content: str) -> "ChatMessage":
+    def from_assistant(cls, content: str, metadata: Optional[Dict[str, Any]] = None) -> "ChatMessage":
         """
         Create a message from the assistant.

         :param content: The text content of the message.
+        :param metadata: Additional metadata associated with the message.
         :return: A new ChatMessage instance.
         """
-        return cls(content, ChatRole.ASSISTANT, None)
+        return cls(content, ChatRole.ASSISTANT, None, metadata or {})

     @classmethod
     def from_user(cls, content: str) -> "ChatMessage":
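A small sketch of the extended factory method; content and metadata values are illustrative:

```python
from haystack.preview.dataclasses import ChatMessage

# Metadata now travels with the message, e.g. the finish_reason a generator records.
msg = ChatMessage.from_assistant("Paris.", metadata={"finish_reason": "stop"})
assert msg.metadata["finish_reason"] == "stop"
```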

@@ -0,0 +1,4 @@
---
preview:
- |
Adds GPTChatGenerator, a chat-based OpenAI LLM component that uses ChatMessage instances for input and output

@@ -0,0 +1,336 @@
import os
from unittest.mock import patch, Mock
import openai
import pytest
from haystack.preview.components.generators.chat import GPTChatGenerator
from haystack.preview.components.generators.utils import default_streaming_callback
from haystack.preview.dataclasses import ChatMessage, StreamingChunk
@pytest.fixture
def mock_chat_completion():
"""
Mock the OpenAI API completion response and reuse it for tests
"""
with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
# mimic the response from the OpenAI API
mock_choice = Mock()
mock_choice.index = 0
mock_choice.finish_reason = "stop"
mock_message = Mock()
mock_message.content = "I'm fine, thanks. How are you?"
mock_message.role = "user"
mock_choice.message = mock_message
mock_response = Mock()
mock_response.model = "gpt-3.5-turbo"
mock_response.usage = Mock()
mock_response.usage.items.return_value = [
("prompt_tokens", 57),
("completion_tokens", 40),
("total_tokens", 97),
]
mock_response.choices = [mock_choice]
mock_chat_completion_create.return_value = mock_response
yield mock_chat_completion_create
def streaming_chunk(content: str):
"""
Mock chunks of streaming responses from the OpenAI API
"""
# mimic the chunk response from the OpenAI API
mock_choice = Mock()
mock_choice.index = 0
mock_choice.delta.content = content
mock_choice.finish_reason = "stop"
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_response.model = "gpt-3.5-turbo"
mock_response.usage = Mock()
mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)]
return mock_response
@pytest.fixture
def chat_messages():
return [
ChatMessage.from_system("You are a helpful assistant"),
ChatMessage.from_user("What's the capital of France"),
]
class TestGPTChatGenerator:
@pytest.mark.unit
def test_init_default(self):
component = GPTChatGenerator(api_key="test-api-key")
assert openai.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert component.api_base_url == "https://api.openai.com/v1"
assert openai.api_base == "https://api.openai.com/v1"
assert not component.generation_kwargs
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError, match="GPTChatGenerator expects an OpenAI API key"):
GPTChatGenerator()
@pytest.mark.unit
def test_init_with_parameters(self):
component = GPTChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
assert openai.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert openai.api_base == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_to_dict_default(self):
component = GPTChatGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "GPTChatGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"streaming_callback": None,
"api_base_url": "https://api.openai.com/v1",
},
}
@pytest.mark.unit
def test_to_dict_with_parameters(self):
component = GPTChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
@pytest.mark.unit
def test_to_dict_with_lambda_streaming_callback(self):
component = GPTChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "chat.test_openai.<lambda>",
},
}
@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "GPTChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
component = GPTChatGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "GPTChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
with pytest.raises(ValueError, match="GPTChatGenerator expects an OpenAI API key"):
GPTChatGenerator.from_dict(data)
@pytest.mark.unit
def test_run(self, chat_messages, mock_chat_completion):
component = GPTChatGenerator(api_key="test-api-key")
response = component.run(chat_messages)
# check that the component returns the correct ChatMessage response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_run_with_params(self, chat_messages, mock_chat_completion):
component = GPTChatGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)
response = component.run(chat_messages)
# check that the component calls the OpenAI API with the correct parameters
_, kwargs = mock_chat_completion.call_args
assert kwargs["max_tokens"] == 10
assert kwargs["temperature"] == 0.5
# check that the component returns the correct response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_run_streaming(self, chat_messages, mock_chat_completion):
streaming_call_count = 0
# Define the streaming callback function and assert that it is called with StreamingChunk objects
def streaming_callback_fn(chunk: StreamingChunk):
nonlocal streaming_call_count
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)
generator = GPTChatGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn)
# Create a fake streamed response
# self needed here, don't remove
def mock_iter(self):
yield streaming_chunk("Hello")
yield streaming_chunk("How are you?")
mock_response = Mock(**{"__iter__": mock_iter})
mock_chat_completion.return_value = mock_response
response = generator.run(chat_messages)
# Assert that the streaming callback was called twice
assert streaming_call_count == 2
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
@pytest.mark.unit
def test_check_abnormal_completions(self, caplog):
component = GPTChatGenerator(api_key="test-api-key")
messages = [
ChatMessage.from_assistant(
"", metadata={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
)
for i in range(4)
]
for m in messages:
component._check_finish_reason(m)
# check truncation warning
message_template = (
"The completion for index {index} has been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)
for index in [1, 3]:
assert caplog.records[index].message == message_template.format(index=index)
# check content filter warning
message_template = "The completion for index {index} has been truncated due to the content filter."
for index in [0, 2]:
assert caplog.records[index].message == message_template.format(index=index)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run(self):
chat_messages = [ChatMessage.from_user("What's the capital of France")]
component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "gpt-3.5" in message.metadata["model"]
assert message.metadata["finish_reason"] == "stop"
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = GPTChatGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
component.run(chat_messages)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""
self.counter = 0
def __call__(self, chunk: StreamingChunk) -> None:
self.counter += 1
self.responses += chunk.content if chunk.content else ""
callback = Callback()
component = GPTChatGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
results = component.run([ChatMessage.from_user("What's the capital of France?")])
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
assert "Paris" in message.content
assert "gpt-3.5" in message.metadata["model"]
assert message.metadata["finish_reason"] == "stop"
assert callback.counter > 1
assert "Paris" in callback.responses
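As an aside, a sketch of the `to_dict`/`from_dict` round-trip the tests above exercise; the fake key is illustrative, and `from_dict` re-reads it from the environment in `__init__`:

```python
import os

from haystack.preview.components.generators.chat import GPTChatGenerator

os.environ["OPENAI_API_KEY"] = "fake-api-key"  # illustrative; __init__ requires a key

original = GPTChatGenerator(model_name="gpt-4", max_tokens=10)
restored = GPTChatGenerator.from_dict(original.to_dict())
assert restored.model_name == "gpt-4"
assert restored.generation_kwargs == {"max_tokens": 10}
```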

@@ -202,6 +202,7 @@ class TestGPTGenerator:
         assert len(response["replies"]) == 1
         assert [isinstance(reply, str) for reply in response["replies"]]

+    @pytest.mark.unit
     def test_run_with_params(self, mock_chat_completion):
         component = GPTGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)
         response = component.run("What's Natural Language Processing?")