feat: Adapt GPTGenerator to use str input/output format in Haystack 2.x (#6214)

* Adapt GPTGenerator to string input/output

* Finishing touches

* punctuation upd

* PR feedback

* Small naming fixes

* Update haystack/preview/components/generators/openai.py

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>

* Update class pydoc with a printed response

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Vladimir Blagojevic 2023-11-07 18:00:43 +01:00 committed by GitHub
parent 6c5bfe3da4
commit 5497ca2a45
10 changed files with 655 additions and 564 deletions

View File

@@ -7,7 +7,7 @@ from haystack.preview.document_stores import InMemoryDocumentStore
from haystack.preview.components.writers import DocumentWriter
from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.generators import GPTGenerator
from haystack.preview.components.builders.answer_builder import AnswerBuilder
from haystack.preview.components.builders.prompt_builder import PromptBuilder

View File

@@ -1,5 +1,5 @@
from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator
from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.preview.components.generators.openai import GPTGenerator
__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"]
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]

View File

@@ -0,0 +1,290 @@
import dataclasses
import logging
import os
from typing import Optional, List, Callable, Dict, Any
import openai
from openai.openai_object import OpenAIObject
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk, ChatMessage
logger = logging.getLogger(__name__)
API_BASE_URL = "https://api.openai.com/v1"
@component
class GPTGenerator:
"""
Enables text generation using OpenAI's large language models (LLMs). It supports the gpt-4 and gpt-3.5-turbo
families of models.
Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
directly to this component via the `**generation_kwargs` parameter in `__init__` or the `generation_kwargs`
parameter of the `run` method.
For more details on the parameters supported by the OpenAI API, refer to the OpenAI
[documentation](https://platform.openai.com/docs/api-reference/chat).
```python
from haystack.preview.components.generators import GPTGenerator
client = GPTGenerator()
response = client.run("What's Natural Language Processing? Be brief.")
print(response)
>> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
>> the interaction between computers and human language. It involves enabling computers to understand, interpret,
>> and respond to natural human language in a way that is both meaningful and useful.'], 'metadata': [{'model':
>> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
>> 'completion_tokens': 49, 'total_tokens': 65}}]}
```
Key Features and Compatibility:
- **Primary Compatibility**: Designed to work seamlessly with the gpt-4 and gpt-3.5-turbo families of models.
- **Streaming Support**: Supports streaming responses from the OpenAI API.
- **Customizability**: Supports all parameters accepted by the OpenAI API.
Input and Output Format:
- **String Format**: This component uses strings for both input and output.
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: str = API_BASE_URL,
system_prompt: Optional[str] = None,
**generation_kwargs,
):
"""
Creates an instance of GPTGenerator. Unless specified otherwise in `model_name`, this uses OpenAI's
gpt-3.5-turbo model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
:param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is
omitted, and the default system prompt of the model is used.
:param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
more details.
Some of the supported parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens
comprising the top 10% probability mass are considered.
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
- `stop`: One or more sequences after which the LLM should stop generating tokens.
- `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
- `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"GPTGenerator expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key
self.model_name = model_name
self.generation_kwargs = generation_kwargs
self.system_prompt = system_prompt
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
openai.api_base = api_base_url
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.generation_kwargs,
system_prompt=self.system_prompt,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
return default_from_dict(cls, data)
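# For reference, to_dict() with default settings produces roughly the following (see the unit tests below):
#   {"type": "GPTGenerator",
#    "init_parameters": {"model_name": "gpt-3.5-turbo", "streaming_callback": None,
#                        "system_prompt": None, "api_base_url": "https://api.openai.com/v1"}}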
@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Invoke the text generation inference based on the provided prompt and generation parameters.
:param prompt: The string prompt to use for text generation.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the __init__ method.
For more details on the parameters supported by the OpenAI API, refer to the
OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
:return: A dictionary with `replies`, a list of generated response strings, and `metadata`, a list of
dictionaries with the metadata for each response.
"""
message = ChatMessage.from_user(prompt)
if self.system_prompt:
messages = [ChatMessage.from_system(self.system_prompt), message]
else:
messages = [message]
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = self._convert_to_openai_format(messages)
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=openai_formatted_messages,
stream=self.streaming_callback is not None,
**generation_kwargs,
)
completions: List[ChatMessage]
if self.streaming_callback:
num_responses = generation_kwargs.pop("n", 1)
if num_responses > 1:
raise ValueError("Cannot stream multiple responses, please set n=1.")
chunks: List[StreamingChunk] = []
chunk = None
for chunk in completion:
if chunk.choices:
chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0])
chunks.append(chunk_delta)
self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta
completions = [self._connect_chunks(chunk, chunks)]
else:
completions = [self._build_message(completion, choice) for choice in completion.choices]
# before returning, do post-processing of the completions
for completion in completions:
self._check_finish_reason(completion)
return {
"replies": [message.content for message in completions],
"metadata": [message.metadata for message in completions],
}
def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
"""
Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
:param messages: The list of ChatMessage.
:return: The list of messages in the format expected by the OpenAI API.
"""
openai_chat_message_format = {"role", "content", "name"}
openai_formatted_messages = []
for m in messages:
message_dict = dataclasses.asdict(m)
filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
openai_formatted_messages.append(filtered_message)
return openai_formatted_messages
def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
"""
complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
complete_response.metadata.update(
{
"model": chunk.model,
"index": 0,
"finish_reason": chunk.choices[0].finish_reason,
"usage": {}, # we don't have usage data for streaming responses
}
)
return complete_response
def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage:
"""
Converts the response from the OpenAI API to a ChatMessage.
:param completion: The completion returned by the OpenAI API.
:param choice: The choice returned by the OpenAI API.
:return: The ChatMessage.
"""
message: OpenAIObject = choice.message
content = dict(message.function_call) if choice.finish_reason == "function_call" else message.content
chat_message = ChatMessage.from_assistant(content)
chat_message.metadata.update(
{
"model": completion.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
"usage": dict(completion.usage.items()),
}
)
return chat_message
def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk:
"""
Converts the response from the OpenAI API to a StreamingChunk.
:param chunk: The chunk returned by the OpenAI API.
:param choice: The choice returned by the OpenAI API.
:return: The StreamingChunk.
"""
has_content = bool(hasattr(choice.delta, "content") and choice.delta.content)
if has_content:
content = choice.delta.content
elif hasattr(choice.delta, "function_call"):
content = str(choice.delta.function_call)
else:
content = ""
chunk_message = StreamingChunk(content)
chunk_message.metadata.update(
{"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}
)
return chunk_message
def _check_finish_reason(self, message: ChatMessage) -> None:
"""
Check the `finish_reason` returned with the OpenAI completions.
If the `finish_reason` is `length`, log a warning to the user.
:param message: The message returned by the LLM.
"""
if message.metadata["finish_reason"] == "length":
logger.warning(
"The completion for index %s has been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions.",
message.metadata["index"],
)
if message.metadata["finish_reason"] == "content_filter":
logger.warning(
"The completion for index %s has been truncated due to the content filter.", message.metadata["index"]
)
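For reference, a minimal usage sketch of the new string-based interface (the names are taken from the component above; it assumes `OPENAI_API_KEY` is set in the environment and uses a simple callback that echoes streamed tokens):
```python
from haystack.preview.components.generators import GPTGenerator
from haystack.preview.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    # each streamed token arrives as a StreamingChunk with a .content string
    print(chunk.content, flush=True, end="")

generator = GPTGenerator(
    model_name="gpt-3.5-turbo",
    system_prompt="Answer in one short sentence.",
    streaming_callback=print_chunk,
)
# kwargs passed to run() are merged over the ones given in __init__
result = generator.run("What's Natural Language Processing?", generation_kwargs={"temperature": 0.2})
print(result["replies"][0])   # a plain string
print(result["metadata"][0])  # model, index, finish_reason, usage
```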

View File

@@ -1,225 +0,0 @@
from typing import Optional, List, Callable, Dict, Any
import sys
import logging
from collections import defaultdict
from dataclasses import dataclass, asdict
import os
import openai
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
logger = logging.getLogger(__name__)
API_BASE_URL = "https://api.openai.com/v1"
@dataclass
class _ChatMessage:
content: str
role: str
def default_streaming_callback(chunk):
"""
Default callback function for streaming responses from OpenAI API.
Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
"""
if hasattr(chunk.choices[0].delta, "content"):
print(chunk.choices[0].delta.content, flush=True, end="")
return chunk
@component
class GPTGenerator:
"""
LLM Generator compatible with GPT (ChatGPT) large language models.
Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package)
See [OpenAI GPT API](https://platform.openai.com/docs/guides/chat) for more details.
"""
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable] = None,
api_base_url: str = API_BASE_URL,
**kwargs,
):
"""
Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param model_name: The name of the model to use.
:param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
Typically, a conversation is formatted with a system message first, followed by alternating messages from
the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior
of the assistant. For example, you can modify the personality of the assistant or provide specific
instructions about how it should behave throughout the conversation.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function should accept two parameters: the token received from the stream and **kwargs.
The callback function should return the token to be sent to the stream. If the callback function is not
provided, the token is printed to stdout.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
Some of the supported parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
- `stop`: One or more sequences after which the LLM should stop generating tokens.
- `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
- `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
"""
# if the user does not provide the API key, check if it is set in the module client
api_key = api_key or openai.api_key
if api_key is None:
try:
api_key = os.environ["OPENAI_API_KEY"]
except KeyError as e:
raise ValueError(
"GPTGenerator expects an OpenAI API key. "
"Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
) from e
openai.api_key = api_key
self.model_name = model_name
self.system_prompt = system_prompt
self.model_parameters = kwargs
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
openai.api_base = api_base_url
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
if self.streaming_callback:
module = self.streaming_callback.__module__
if module == "builtins":
callback_name = self.streaming_callback.__name__
else:
callback_name = f"{module}.{self.streaming_callback.__name__}"
else:
callback_name = None
return default_to_dict(
self,
model_name=self.model_name,
system_prompt=self.system_prompt,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
streaming_callback = None
if "streaming_callback" in init_params and init_params["streaming_callback"]:
parts = init_params["streaming_callback"].split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
data["init_parameters"]["streaming_callback"] = streaming_callback
return default_from_dict(cls, data)
@component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
def run(self, prompt: str):
"""
Queries the LLM with the prompts to produce replies.
:param prompts: The prompts to be sent to the generative model.
"""
message = _ChatMessage(content=prompt, role="user")
if self.system_prompt:
chat = [_ChatMessage(content=self.system_prompt, role="system"), message]
else:
chat = [message]
completion = openai.ChatCompletion.create(
model=self.model_name,
messages=[asdict(message) for message in chat],
stream=self.streaming_callback is not None,
**self.model_parameters,
)
replies: List[str]
metadata: List[Dict[str, Any]]
if self.streaming_callback:
replies_dict: Dict[str, str] = defaultdict(str)
metadata_dict: Dict[str, Dict[str, Any]] = defaultdict(dict)
for chunk in completion:
chunk = self.streaming_callback(chunk)
for choice in chunk.choices:
if hasattr(choice.delta, "content"):
replies_dict[choice.index] += choice.delta.content
metadata_dict[choice.index] = {
"model": chunk.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
}
replies = list(replies_dict.values())
metadata = list(metadata_dict.values())
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}
metadata = [
{
"model": completion.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
"usage": dict(completion.usage.items()),
}
for choice in completion.choices
]
replies = [choice.message.content.strip() for choice in completion.choices]
self._check_truncated_answers(metadata)
return {"replies": replies, "metadata": metadata}
def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
"""
Check the `finish_reason` returned with the OpenAI completions.
If the `finish_reason` is `length`, log a warning to the user.
:param result: The result returned from the OpenAI API.
:param payload: The payload sent to the OpenAI API.
"""
truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop")
if truncated_completions > 0:
logger.warning(
"%s out of the %s completions have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions.",
truncated_completions,
len(metadata),
)

View File

@@ -6,6 +6,14 @@ from haystack.preview import DeserializationError
from haystack.preview.dataclasses import StreamingChunk
def default_streaming_callback(chunk: StreamingChunk) -> None:
"""
Default callback function for streaming responses.
Prints the tokens of the first completion to stdout as soon as they are received.
"""
print(chunk.content, flush=True, end="")
def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str:
"""
Serializes the streaming callback handler.

View File

@@ -0,0 +1,4 @@
---
preview:
- |
Adapt GPTGenerator to use strings for input and output

View File

@@ -1,334 +0,0 @@
import os
from unittest.mock import patch, Mock
from copy import deepcopy
import pytest
import openai
from openai.util import convert_to_openai_object
from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.generators.openai.gpt import default_streaming_callback
def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
base_dict = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": model,
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
base_dict["choices"] = [
{"message": {"role": "assistant", "content": response}, "finish_reason": "stop", "index": "0"}
]
return convert_to_openai_object(deepcopy(base_dict))
def mock_openai_stream_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
base_dict = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": model,
}
base_dict["choices"] = [{"delta": {"role": "assistant"}, "finish_reason": None, "index": "0"}]
yield convert_to_openai_object(base_dict)
for token in response.split():
base_dict["choices"][0]["delta"] = {"content": token + " "}
yield convert_to_openai_object(base_dict)
base_dict["choices"] = [{"delta": {"content": ""}, "finish_reason": "stop", "index": "0"}]
yield convert_to_openai_object(base_dict)
class TestGPTGenerator:
@pytest.mark.unit
def test_init_default(self):
component = GPTGenerator(api_key="test-api-key")
assert openai.api_key == "test-api-key"
assert component.system_prompt is None
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert component.api_base_url == "https://api.openai.com/v1"
assert openai.api_base == "https://api.openai.com/v1"
assert component.model_parameters == {}
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
GPTGenerator()
@pytest.mark.unit
def test_init_with_parameters(self):
callback = lambda x: x
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=callback,
api_base_url="test-base-url",
)
assert openai.api_key == "test-api-key"
assert component.system_prompt == "test-system-prompt"
assert component.model_name == "gpt-4"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert openai.api_base == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_to_dict_default(self):
component = GPTGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"system_prompt": None,
"streaming_callback": None,
"api_base_url": "https://api.openai.com/v1",
},
}
@pytest.mark.unit
def test_to_dict_with_parameters(self):
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
},
}
@pytest.mark.unit
def test_to_dict_with_lambda_streaming_callback(self):
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_gpt_generator.<lambda>",
},
}
@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
},
}
component = GPTGenerator.from_dict(data)
assert component.system_prompt == "test-system-prompt"
assert component.model_name == "gpt-4"
assert component.streaming_callback == default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
},
}
with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
GPTGenerator.from_dict(data)
@pytest.mark.unit
def test_run_no_system_prompt(self):
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
gpt_patch.create.side_effect = mock_openai_response
component = GPTGenerator(api_key="test-api-key")
results = component.run(prompt="test-prompt-1")
assert results == {
"replies": ["response for these messages --> user: test-prompt-1"],
"metadata": [
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
}
gpt_patch.create.assert_called_once_with(
model="gpt-3.5-turbo", messages=[{"role": "user", "content": "test-prompt-1"}], stream=False
)
@pytest.mark.unit
def test_run_with_system_prompt(self):
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
gpt_patch.create.side_effect = mock_openai_response
component = GPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt")
results = component.run(prompt="test-prompt-1")
assert results == {
"replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1"],
"metadata": [
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
}
gpt_patch.create.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-1"},
],
stream=False,
)
@pytest.mark.unit
def test_run_with_parameters(self):
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
gpt_patch.create.side_effect = mock_openai_response
component = GPTGenerator(api_key="test-api-key", max_tokens=10)
component.run(prompt="test-prompt-1")
gpt_patch.create.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "test-prompt-1"}],
stream=False,
max_tokens=10,
)
@pytest.mark.unit
def test_run_stream(self):
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
mock_callback = Mock()
mock_callback.side_effect = default_streaming_callback
gpt_patch.create.side_effect = mock_openai_stream_response
component = GPTGenerator(
api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback
)
results = component.run(prompt="test-prompt-1")
assert results == {
"replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1 "],
"metadata": [{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}],
}
# Calls count: 10 tokens per prompt + 1 token for the role + 1 empty termination token
assert mock_callback.call_count == 12
gpt_patch.create.assert_called_once_with(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-1"},
],
stream=True,
)
@pytest.mark.unit
def test_check_truncated_answers(self, caplog):
component = GPTGenerator(api_key="test-api-key")
metadata = [
{"finish_reason": "stop"},
{"finish_reason": "content_filter"},
{"finish_reason": "length"},
{"finish_reason": "stop"},
]
component._check_truncated_answers(metadata)
assert caplog.records[0].message == (
"2 out of the 4 completions have been truncated before reaching a natural "
"stopping point. Increase the max_tokens parameter to allow for longer completions."
)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_gpt_generator_run(self):
component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert "gpt-3.5" in results["metadata"][0]["model"]
assert results["metadata"][0]["finish_reason"] == "stop"
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_gpt_generator_run_wrong_model_name(self):
component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
component.run(prompt="What's the capital of France?")
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_gpt_generator_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""
def __call__(self, chunk):
self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else ""
return chunk
callback = Callback()
component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert "gpt-3.5" in results["metadata"][0]["model"]
assert results["metadata"][0]["finish_reason"] == "stop"
assert callback.responses == results["replies"][0]

View File

@@ -0,0 +1,348 @@
import os
from typing import List
from unittest.mock import patch, Mock
import openai
import pytest
from haystack.preview.components.generators import GPTGenerator
from haystack.preview.components.generators.utils import default_streaming_callback
from haystack.preview.dataclasses import StreamingChunk, ChatMessage
@pytest.fixture
def mock_chat_completion():
"""
Mock the OpenAI API completion response and reuse it for tests
"""
with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
# mimic the response from the OpenAI API
mock_choice = Mock()
mock_choice.index = 0
mock_choice.finish_reason = "stop"
mock_message = Mock()
mock_message.content = "I'm fine, thanks. How are you?"
mock_message.role = "user"
mock_choice.message = mock_message
mock_response = Mock()
mock_response.model = "gpt-3.5-turbo"
mock_response.usage = Mock()
mock_response.usage.items.return_value = [
("prompt_tokens", 57),
("completion_tokens", 40),
("total_tokens", 97),
]
mock_response.choices = [mock_choice]
mock_chat_completion_create.return_value = mock_response
yield mock_chat_completion_create
def streaming_chunk(content: str):
"""
Mock chunks of streaming responses from the OpenAI API
"""
# mimic the chunk response from the OpenAI API
mock_choice = Mock()
mock_choice.index = 0
mock_choice.delta.content = content
mock_choice.finish_reason = "stop"
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_response.model = "gpt-3.5-turbo"
mock_response.usage = Mock()
mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)]
return mock_response
class TestGPTGenerator:
@pytest.mark.unit
def test_init_default(self):
component = GPTGenerator(api_key="test-api-key")
assert openai.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert component.api_base_url == "https://api.openai.com/v1"
assert openai.api_base == "https://api.openai.com/v1"
assert not component.generation_kwargs
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
GPTGenerator()
@pytest.mark.unit
def test_init_with_parameters(self):
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
assert openai.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert openai.api_base == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_to_dict_default(self):
component = GPTGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"streaming_callback": None,
"system_prompt": None,
"api_base_url": "https://api.openai.com/v1",
},
}
@pytest.mark.unit
def test_to_dict_with_parameters(self):
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"system_prompt": None,
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
@pytest.mark.unit
def test_to_dict_with_lambda_streaming_callback(self):
component = GPTGenerator(
api_key="test-api-key",
model_name="gpt-4",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"system_prompt": None,
"api_base_url": "test-base-url",
"streaming_callback": "test_openai.<lambda>",
},
}
@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"system_prompt": None,
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
component = GPTGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
openai.api_key = None
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "GPTGenerator",
"init_parameters": {
"model_name": "gpt-4",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
},
}
with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
GPTGenerator.from_dict(data)
@pytest.mark.unit
def test_run(self, mock_chat_completion):
component = GPTGenerator(api_key="test-api-key")
response = component.run("What's Natural Language Processing?")
# check that the component returns the correct ChatMessage response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, str) for reply in response["replies"]]
def test_run_with_params(self, mock_chat_completion):
component = GPTGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)
response = component.run("What's Natural Language Processing?")
# check that the component calls the OpenAI API with the correct parameters
_, kwargs = mock_chat_completion.call_args
assert kwargs["max_tokens"] == 10
assert kwargs["temperature"] == 0.5
# check that the component returns the correct response
assert isinstance(response, dict)
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) == 1
assert [isinstance(reply, str) for reply in response["replies"]]
@pytest.mark.unit
def test_run_streaming(self, mock_chat_completion):
streaming_call_count = 0
# Define the streaming callback function and assert that it is called with StreamingChunk objects
def streaming_callback_fn(chunk: StreamingChunk):
nonlocal streaming_call_count
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)
generator = GPTGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn)
# Create a fake streamed response
# self needed here, don't remove
def mock_iter(self):
yield streaming_chunk("Hello")
yield streaming_chunk("How are you?")
mock_response = Mock(**{"__iter__": mock_iter})
mock_chat_completion.return_value = mock_response
response = generator.run("Hello there")
# Assert that the streaming callback was called twice
assert streaming_call_count == 2
# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
assert len(response["replies"]) > 0
assert [isinstance(reply, str) for reply in response["replies"]]
@pytest.mark.unit
def test_check_abnormal_completions(self, caplog):
component = GPTGenerator(api_key="test-api-key")
# underlying implementation uses ChatMessage objects so we have to use them here
messages: List[ChatMessage] = []
for i, _ in enumerate(range(4)):
message = ChatMessage.from_assistant("Hello")
metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
message.metadata.update(metadata)
messages.append(message)
for m in messages:
component._check_finish_reason(m)
# check truncation warning
message_template = (
"The completion for index {index} has been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions."
)
for index in [1, 3]:
assert caplog.records[index].message == message_template.format(index=index)
# check content filter warning
message_template = "The completion for index {index} has been truncated due to the content filter."
for index in [0, 2]:
assert caplog.records[index].message == message_template.format(index=index)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run(self):
component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1
assert len(results["metadata"]) == 1
response: str = results["replies"][0]
assert "Paris" in response
metadata = results["metadata"][0]
assert "gpt-3.5" in metadata["model"]
assert metadata["finish_reason"] == "stop"
assert "usage" in metadata
assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_wrong_model(self):
component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
component.run("Whatever")
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_live_run_streaming(self):
class Callback:
def __init__(self):
self.responses = ""
self.counter = 0
def __call__(self, chunk: StreamingChunk) -> None:
self.counter += 1
self.responses += chunk.content if chunk.content else ""
callback = Callback()
component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1
assert len(results["metadata"]) == 1
response: str = results["replies"][0]
assert "Paris" in response
metadata = results["metadata"][0]
assert "gpt-3.5" in metadata["model"]
assert metadata["finish_reason"] == "stop"
# unfortunately, the usage is not available for streaming calls
# we keep the key in the metadata for compatibility
assert "usage" in metadata and len(metadata["usage"]) == 0
assert callback.counter > 1
assert "Paris" in callback.responses

View File

@@ -1,6 +1,6 @@
import pytest
from haystack.preview.components.generators.openai.gpt import default_streaming_callback
from haystack.preview.components.generators.utils import default_streaming_callback
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
@@ -18,7 +18,7 @@ def test_callback_handler_serialization():
@pytest.mark.unit
def test_callback_handler_serialization_non_local():
result = serialize_callback_handler(default_streaming_callback)
assert result == "haystack.preview.components.generators.openai.gpt.default_streaming_callback"
assert result == "haystack.preview.components.generators.utils.default_streaming_callback"
@pytest.mark.unit