diff --git a/e2e/preview/pipelines/test_rag_pipelines.py b/e2e/preview/pipelines/test_rag_pipelines.py index e9835a671..dc17b03f5 100644 --- a/e2e/preview/pipelines/test_rag_pipelines.py +++ b/e2e/preview/pipelines/test_rag_pipelines.py @@ -7,7 +7,7 @@ from haystack.preview.document_stores import InMemoryDocumentStore from haystack.preview.components.writers import DocumentWriter from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder -from haystack.preview.components.generators.openai.gpt import GPTGenerator +from haystack.preview.components.generators import GPTGenerator from haystack.preview.components.builders.answer_builder import AnswerBuilder from haystack.preview.components.builders.prompt_builder import PromptBuilder diff --git a/haystack/preview/components/generators/__init__.py b/haystack/preview/components/generators/__init__.py index 331f49a0a..761b977d2 100644 --- a/haystack/preview/components/generators/__init__.py +++ b/haystack/preview/components/generators/__init__.py @@ -1,5 +1,5 @@ -from haystack.preview.components.generators.openai.gpt import GPTGenerator from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator +from haystack.preview.components.generators.openai import GPTGenerator -__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"] +__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"] diff --git a/haystack/preview/components/generators/openai.py b/haystack/preview/components/generators/openai.py new file mode 100644 index 000000000..833de129a --- /dev/null +++ b/haystack/preview/components/generators/openai.py @@ -0,0 +1,290 @@ +import dataclasses +import logging +import os +from typing import Optional, List, Callable, Dict, Any + +import openai +from openai.openai_object import OpenAIObject + +from haystack.preview import component, default_from_dict, default_to_dict +from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler +from haystack.preview.dataclasses import StreamingChunk, ChatMessage + +logger = logging.getLogger(__name__) + + +API_BASE_URL = "https://api.openai.com/v1" + + +@component +class GPTGenerator: + """ + Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo + family of models. + + Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method + directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` + parameter in `run` method. + + For more details on the parameters supported by the OpenAI API, refer to the OpenAI + [documentation](https://platform.openai.com/docs/api-reference/chat). + + ```python + from haystack.preview.components.generators import GPTGenerator + client = GPTGenerator() + response = client.run("What's Natural Language Processing? Be brief.") + print(response) + + >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on + >> the interaction between computers and human language. 
It involves enabling computers to understand, interpret, + >> and respond to natural human language in a way that is both meaningful and useful.'], 'metadata': [{'model': + >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16, + >> 'completion_tokens': 49, 'total_tokens': 65}}]} + ``` + + Key Features and Compatibility: + - **Primary Compatibility**: Designed to work seamlessly with gpt-4, gpt-3.5-turbo family of models. + - **Streaming Support**: Supports streaming responses from the OpenAI API. + - **Customizability**: Supports all parameters supported by the OpenAI API. + + Input and Output Format: + - **String Format**: This component uses the strings for both input and output. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + api_base_url: str = API_BASE_URL, + system_prompt: Optional[str] = None, + **generation_kwargs, + ): + """ + Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's + GPT-3.5 model. + + :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the + environment variable OPENAI_API_KEY (recommended). + :param model_name: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is + omitted, and the default system prompt of the model is used. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for + more details. + Some of the supported parameters: + - `max_tokens`: The maximum number of tokens the output text can have. + - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. + Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer. + - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens + comprising the top 10% probability mass are considered. + - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, + it will generate two completions for each of the three prompts, ending up with 6 completions in total. + - `stop`: One or more sequences after which the LLM should stop generating tokens. + - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the + values are the bias to add to that token. 
+ """ + # if the user does not provide the API key, check if it is set in the module client + api_key = api_key or openai.api_key + if api_key is None: + try: + api_key = os.environ["OPENAI_API_KEY"] + except KeyError as e: + raise ValueError( + "GPTGenerator expects an OpenAI API key. " + "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." + ) from e + openai.api_key = api_key + + self.model_name = model_name + self.generation_kwargs = generation_kwargs + self.system_prompt = system_prompt + self.streaming_callback = streaming_callback + + self.api_base_url = api_base_url + openai.api_base = api_base_url + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model_name} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. + """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model_name=self.model_name, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + **self.generation_kwargs, + system_prompt=self.system_prompt, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + + @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param prompt: The string prompt to use for text generation. + :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the __init__ method. + For more details on the parameters supported by the OpenAI API, refer to the + OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create). + :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata + for each response. 
+ """ + message = ChatMessage.from_user(prompt) + if self.system_prompt: + messages = [ChatMessage.from_system(self.system_prompt), message] + else: + messages = [message] + + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + # adapt ChatMessage(s) to the format expected by the OpenAI API + openai_formatted_messages = self._convert_to_openai_format(messages) + + completion = openai.ChatCompletion.create( + model=self.model_name, + messages=openai_formatted_messages, + stream=self.streaming_callback is not None, + **generation_kwargs, + ) + + completions: List[ChatMessage] + if self.streaming_callback: + num_responses = generation_kwargs.pop("n", 1) + if num_responses > 1: + raise ValueError("Cannot stream multiple responses, please set n=1.") + chunks: List[StreamingChunk] = [] + chunk = None + for chunk in completion: + if chunk.choices: + chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0]) + chunks.append(chunk_delta) + self.streaming_callback(chunk_delta) # invoke callback with the chunk_delta + completions = [self._connect_chunks(chunk, chunks)] + else: + completions = [self._build_message(completion, choice) for choice in completion.choices] + + # before returning, do post-processing of the completions + for completion in completions: + self._check_finish_reason(completion) + + return { + "replies": [message.content for message in completions], + "metadata": [message.metadata for message in completions], + } + + def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """ + Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API. + :param messages: The list of ChatMessage. + :return: The list of messages in the format expected by the OpenAI API. + """ + openai_chat_message_format = {"role", "content", "name"} + openai_formatted_messages = [] + for m in messages: + message_dict = dataclasses.asdict(m) + filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v} + openai_formatted_messages.append(filtered_message) + return openai_formatted_messages + + def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage: + """ + Connects the streaming chunks into a single ChatMessage. + """ + complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks])) + complete_response.metadata.update( + { + "model": chunk.model, + "index": 0, + "finish_reason": chunk.choices[0].finish_reason, + "usage": {}, # we don't have usage data for streaming responses + } + ) + return complete_response + + def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage: + """ + Converts the response from the OpenAI API to a ChatMessage. + :param completion: The completion returned by the OpenAI API. + :param choice: The choice returned by the OpenAI API. + :return: The ChatMessage. 
+ """ + message: OpenAIObject = choice.message + content = dict(message.function_call) if choice.finish_reason == "function_call" else message.content + chat_message = ChatMessage.from_assistant(content) + chat_message.metadata.update( + { + "model": completion.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + "usage": dict(completion.usage.items()), + } + ) + return chat_message + + def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk: + """ + Converts the response from the OpenAI API to a StreamingChunk. + :param chunk: The chunk returned by the OpenAI API. + :param choice: The choice returned by the OpenAI API. + :return: The StreamingChunk. + """ + has_content = bool(hasattr(choice.delta, "content") and choice.delta.content) + if has_content: + content = choice.delta.content + elif hasattr(choice.delta, "function_call"): + content = str(choice.delta.function_call) + else: + content = "" + chunk_message = StreamingChunk(content) + chunk_message.metadata.update( + {"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason} + ) + return chunk_message + + def _check_finish_reason(self, message: ChatMessage) -> None: + """ + Check the `finish_reason` returned with the OpenAI completions. + If the `finish_reason` is `length`, log a warning to the user. + :param message: The message returned by the LLM. + """ + if message.metadata["finish_reason"] == "length": + logger.warning( + "The completion for index %s has been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions.", + message.metadata["index"], + ) + if message.metadata["finish_reason"] == "content_filter": + logger.warning( + "The completion for index %s has been truncated due to the content filter.", message.metadata["index"] + ) diff --git a/haystack/preview/components/generators/openai/__init__.py b/haystack/preview/components/generators/openai/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/haystack/preview/components/generators/openai/gpt.py b/haystack/preview/components/generators/openai/gpt.py deleted file mode 100644 index 56f8c3582..000000000 --- a/haystack/preview/components/generators/openai/gpt.py +++ /dev/null @@ -1,225 +0,0 @@ -from typing import Optional, List, Callable, Dict, Any - -import sys -import logging -from collections import defaultdict -from dataclasses import dataclass, asdict -import os - -import openai - -from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError - - -logger = logging.getLogger(__name__) - - -API_BASE_URL = "https://api.openai.com/v1" - - -@dataclass -class _ChatMessage: - content: str - role: str - - -def default_streaming_callback(chunk): - """ - Default callback function for streaming responses from OpenAI API. - Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged. - """ - if hasattr(chunk.choices[0].delta, "content"): - print(chunk.choices[0].delta.content, flush=True, end="") - return chunk - - -@component -class GPTGenerator: - """ - LLM Generator compatible with GPT (ChatGPT) large language models. - - Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package) - See [OpenAI GPT API](https://platform.openai.com/docs/guides/chat) for more details. 
- """ - - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "gpt-3.5-turbo", - system_prompt: Optional[str] = None, - streaming_callback: Optional[Callable] = None, - api_base_url: str = API_BASE_URL, - **kwargs, - ): - """ - Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's GPT-3.5 model. - - :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the - environment variable OPENAI_API_KEY (recommended). - :param model_name: The name of the model to use. - :param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation. - Typically, a conversation is formatted with a system message first, followed by alternating messages from - the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior - of the assistant. For example, you can modify the personality of the assistant or provide specific - instructions about how it should behave throughout the conversation. - :param streaming_callback: A callback function that is called when a new token is received from the stream. - The callback function should accept two parameters: the token received from the stream and **kwargs. - The callback function should return the token to be sent to the stream. If the callback function is not - provided, the token is printed to stdout. - :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. - :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI - endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. - Some of the supported parameters: - - `max_tokens`: The maximum number of tokens the output text can have. - - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. - Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. - - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model - considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens - comprising the top 10% probability mass are considered. - - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, - it will generate two completions for each of the three prompts, ending up with 6 completions in total. - - `stop`: One or more sequences after which the LLM should stop generating tokens. - - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean - the model will be less likely to repeat the same token in the text. - - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. - Bigger values mean the model will be less likely to repeat the same token in the text. - - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the - values are the bias to add to that token. - """ - # if the user does not provide the API key, check if it is set in the module client - api_key = api_key or openai.api_key - if api_key is None: - try: - api_key = os.environ["OPENAI_API_KEY"] - except KeyError as e: - raise ValueError( - "GPTGenerator expects an OpenAI API key. " - "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly." 
- ) from e - openai.api_key = api_key - - self.model_name = model_name - self.system_prompt = system_prompt - self.model_parameters = kwargs - self.streaming_callback = streaming_callback - - self.api_base_url = api_base_url - openai.api_base = api_base_url - - def _get_telemetry_data(self) -> Dict[str, Any]: - """ - Data that is sent to Posthog for usage analytics. - """ - return {"model": self.model_name} - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize this component to a dictionary. - """ - if self.streaming_callback: - module = self.streaming_callback.__module__ - if module == "builtins": - callback_name = self.streaming_callback.__name__ - else: - callback_name = f"{module}.{self.streaming_callback.__name__}" - else: - callback_name = None - - return default_to_dict( - self, - model_name=self.model_name, - system_prompt=self.system_prompt, - streaming_callback=callback_name, - api_base_url=self.api_base_url, - **self.model_parameters, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator": - """ - Deserialize this component from a dictionary. - """ - init_params = data.get("init_parameters", {}) - streaming_callback = None - if "streaming_callback" in init_params and init_params["streaming_callback"]: - parts = init_params["streaming_callback"].split(".") - module_name = ".".join(parts[:-1]) - function_name = parts[-1] - module = sys.modules.get(module_name, None) - if not module: - raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}") - streaming_callback = getattr(module, function_name, None) - if not streaming_callback: - raise DeserializationError(f"Could not locate the streaming callback: {function_name}") - data["init_parameters"]["streaming_callback"] = streaming_callback - return default_from_dict(cls, data) - - @component.output_types(replies=List[str], metadata=List[Dict[str, Any]]) - def run(self, prompt: str): - """ - Queries the LLM with the prompts to produce replies. - - :param prompts: The prompts to be sent to the generative model. 
- """ - message = _ChatMessage(content=prompt, role="user") - if self.system_prompt: - chat = [_ChatMessage(content=self.system_prompt, role="system"), message] - else: - chat = [message] - - completion = openai.ChatCompletion.create( - model=self.model_name, - messages=[asdict(message) for message in chat], - stream=self.streaming_callback is not None, - **self.model_parameters, - ) - - replies: List[str] - metadata: List[Dict[str, Any]] - if self.streaming_callback: - replies_dict: Dict[str, str] = defaultdict(str) - metadata_dict: Dict[str, Dict[str, Any]] = defaultdict(dict) - for chunk in completion: - chunk = self.streaming_callback(chunk) - for choice in chunk.choices: - if hasattr(choice.delta, "content"): - replies_dict[choice.index] += choice.delta.content - metadata_dict[choice.index] = { - "model": chunk.model, - "index": choice.index, - "finish_reason": choice.finish_reason, - } - replies = list(replies_dict.values()) - metadata = list(metadata_dict.values()) - self._check_truncated_answers(metadata) - return {"replies": replies, "metadata": metadata} - - metadata = [ - { - "model": completion.model, - "index": choice.index, - "finish_reason": choice.finish_reason, - "usage": dict(completion.usage.items()), - } - for choice in completion.choices - ] - replies = [choice.message.content.strip() for choice in completion.choices] - self._check_truncated_answers(metadata) - return {"replies": replies, "metadata": metadata} - - def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): - """ - Check the `finish_reason` returned with the OpenAI completions. - If the `finish_reason` is `length`, log a warning to the user. - - :param result: The result returned from the OpenAI API. - :param payload: The payload sent to the OpenAI API. - """ - truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop") - if truncated_completions > 0: - logger.warning( - "%s out of the %s completions have been truncated before reaching a natural stopping point. " - "Increase the max_tokens parameter to allow for longer completions.", - truncated_completions, - len(metadata), - ) diff --git a/haystack/preview/components/generators/utils.py b/haystack/preview/components/generators/utils.py index 9db2fd488..397009e4e 100644 --- a/haystack/preview/components/generators/utils.py +++ b/haystack/preview/components/generators/utils.py @@ -6,6 +6,14 @@ from haystack.preview import DeserializationError from haystack.preview.dataclasses import StreamingChunk +def default_streaming_callback(chunk: StreamingChunk) -> None: + """ + Default callback function for streaming responses. + Prints the tokens of the first completion to stdout as soon as they are received + """ + print(chunk.content, flush=True, end="") + + def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str: """ Serializes the streaming callback handler. 
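Reviewer note: with `default_streaming_callback` relocated to `generators.utils`, the callback serialization helpers and `GPTGenerator.to_dict`/`from_dict` now round-trip together. A minimal sketch of that round trip, based on the code in this diff (the API key value is a placeholder):

```python
from haystack.preview.components.generators import GPTGenerator
from haystack.preview.components.generators.utils import default_streaming_callback

# Stream tokens to stdout via the relocated default callback.
client = GPTGenerator(api_key="sk-placeholder", streaming_callback=default_streaming_callback)

# to_dict() stores the callback as its dotted import path, i.e.
# "haystack.preview.components.generators.utils.default_streaming_callback".
data = client.to_dict()

# The API key is never serialized; from_dict() relies on openai.api_key
# (already set by the constructor call above) or the OPENAI_API_KEY env var.
restored = GPTGenerator.from_dict(data)
assert restored.streaming_callback is default_streaming_callback
```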
diff --git a/releasenotes/notes/adapt-gpt-generator-bb7f52bd67f6b197.yaml b/releasenotes/notes/adapt-gpt-generator-bb7f52bd67f6b197.yaml new file mode 100644 index 000000000..8382ecf14 --- /dev/null +++ b/releasenotes/notes/adapt-gpt-generator-bb7f52bd67f6b197.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Adapt GPTGenerator to use strings for input and output diff --git a/test/preview/components/generators/openai/test_gpt_generator.py b/test/preview/components/generators/openai/test_gpt_generator.py deleted file mode 100644 index c30432f48..000000000 --- a/test/preview/components/generators/openai/test_gpt_generator.py +++ /dev/null @@ -1,334 +0,0 @@ -import os -from unittest.mock import patch, Mock -from copy import deepcopy - -import pytest -import openai -from openai.util import convert_to_openai_object - -from haystack.preview.components.generators.openai.gpt import GPTGenerator -from haystack.preview.components.generators.openai.gpt import default_streaming_callback - - -def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion: - response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}" - base_dict = { - "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V", - "object": "chat.completion", - "created": 1685855844, - "model": model, - "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, - } - base_dict["choices"] = [ - {"message": {"role": "assistant", "content": response}, "finish_reason": "stop", "index": "0"} - ] - return convert_to_openai_object(deepcopy(base_dict)) - - -def mock_openai_stream_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion: - response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}" - base_dict = { - "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V", - "object": "chat.completion", - "created": 1685855844, - "model": model, - } - base_dict["choices"] = [{"delta": {"role": "assistant"}, "finish_reason": None, "index": "0"}] - yield convert_to_openai_object(base_dict) - for token in response.split(): - base_dict["choices"][0]["delta"] = {"content": token + " "} - yield convert_to_openai_object(base_dict) - base_dict["choices"] = [{"delta": {"content": ""}, "finish_reason": "stop", "index": "0"}] - yield convert_to_openai_object(base_dict) - - -class TestGPTGenerator: - @pytest.mark.unit - def test_init_default(self): - component = GPTGenerator(api_key="test-api-key") - assert openai.api_key == "test-api-key" - assert component.system_prompt is None - assert component.model_name == "gpt-3.5-turbo" - assert component.streaming_callback is None - assert component.api_base_url == "https://api.openai.com/v1" - assert openai.api_base == "https://api.openai.com/v1" - assert component.model_parameters == {} - - @pytest.mark.unit - def test_init_fail_wo_api_key(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): - GPTGenerator() - - @pytest.mark.unit - def test_init_with_parameters(self): - callback = lambda x: x - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - system_prompt="test-system-prompt", - max_tokens=10, - some_test_param="test-params", - streaming_callback=callback, - api_base_url="test-base-url", - ) - assert openai.api_key == "test-api-key" - assert component.system_prompt == "test-system-prompt" 
- assert component.model_name == "gpt-4" - assert component.streaming_callback == callback - assert component.api_base_url == "test-base-url" - assert openai.api_base == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_to_dict_default(self): - component = GPTGenerator(api_key="test-api-key") - data = component.to_dict() - assert data == { - "type": "GPTGenerator", - "init_parameters": { - "model_name": "gpt-3.5-turbo", - "system_prompt": None, - "streaming_callback": None, - "api_base_url": "https://api.openai.com/v1", - }, - } - - @pytest.mark.unit - def test_to_dict_with_parameters(self): - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - system_prompt="test-system-prompt", - max_tokens=10, - some_test_param="test-params", - streaming_callback=default_streaming_callback, - api_base_url="test-base-url", - ) - data = component.to_dict() - assert data == { - "type": "GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": "test-system-prompt", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback", - }, - } - - @pytest.mark.unit - def test_to_dict_with_lambda_streaming_callback(self): - component = GPTGenerator( - api_key="test-api-key", - model_name="gpt-4", - system_prompt="test-system-prompt", - max_tokens=10, - some_test_param="test-params", - streaming_callback=lambda x: x, - api_base_url="test-base-url", - ) - data = component.to_dict() - assert data == { - "type": "GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": "test-system-prompt", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "test_gpt_generator.", - }, - } - - @pytest.mark.unit - def test_from_dict(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") - data = { - "type": "GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": "test-system-prompt", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback", - }, - } - component = GPTGenerator.from_dict(data) - assert component.system_prompt == "test-system-prompt" - assert component.model_name == "gpt-4" - assert component.streaming_callback == default_streaming_callback - assert component.api_base_url == "test-base-url" - assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - - @pytest.mark.unit - def test_from_dict_fail_wo_env_var(self, monkeypatch): - openai.api_key = None - monkeypatch.delenv("OPENAI_API_KEY", raising=False) - data = { - "type": "GPTGenerator", - "init_parameters": { - "model_name": "gpt-4", - "system_prompt": "test-system-prompt", - "max_tokens": 10, - "some_test_param": "test-params", - "api_base_url": "test-base-url", - "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback", - }, - } - with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): - GPTGenerator.from_dict(data) - - @pytest.mark.unit - def test_run_no_system_prompt(self): - with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch: - gpt_patch.create.side_effect = mock_openai_response - 
component = GPTGenerator(api_key="test-api-key") - results = component.run(prompt="test-prompt-1") - assert results == { - "replies": ["response for these messages --> user: test-prompt-1"], - "metadata": [ - { - "model": "gpt-3.5-turbo", - "index": "0", - "finish_reason": "stop", - "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, - } - ], - } - gpt_patch.create.assert_called_once_with( - model="gpt-3.5-turbo", messages=[{"role": "user", "content": "test-prompt-1"}], stream=False - ) - - @pytest.mark.unit - def test_run_with_system_prompt(self): - with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch: - gpt_patch.create.side_effect = mock_openai_response - component = GPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt") - results = component.run(prompt="test-prompt-1") - assert results == { - "replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1"], - "metadata": [ - { - "model": "gpt-3.5-turbo", - "index": "0", - "finish_reason": "stop", - "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, - } - ], - } - gpt_patch.create.assert_called_once_with( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "test-system-prompt"}, - {"role": "user", "content": "test-prompt-1"}, - ], - stream=False, - ) - - @pytest.mark.unit - def test_run_with_parameters(self): - with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch: - gpt_patch.create.side_effect = mock_openai_response - component = GPTGenerator(api_key="test-api-key", max_tokens=10) - component.run(prompt="test-prompt-1") - gpt_patch.create.assert_called_once_with( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "test-prompt-1"}], - stream=False, - max_tokens=10, - ) - - @pytest.mark.unit - def test_run_stream(self): - with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch: - mock_callback = Mock() - mock_callback.side_effect = default_streaming_callback - gpt_patch.create.side_effect = mock_openai_stream_response - component = GPTGenerator( - api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback - ) - results = component.run(prompt="test-prompt-1") - assert results == { - "replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1 "], - "metadata": [{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}], - } - # Calls count: 10 tokens per prompt + 1 token for the role + 1 empty termination token - assert mock_callback.call_count == 12 - gpt_patch.create.assert_called_once_with( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "test-system-prompt"}, - {"role": "user", "content": "test-prompt-1"}, - ], - stream=True, - ) - - @pytest.mark.unit - def test_check_truncated_answers(self, caplog): - component = GPTGenerator(api_key="test-api-key") - metadata = [ - {"finish_reason": "stop"}, - {"finish_reason": "content_filter"}, - {"finish_reason": "length"}, - {"finish_reason": "stop"}, - ] - component._check_truncated_answers(metadata) - assert caplog.records[0].message == ( - "2 out of the 4 completions have been truncated before reaching a natural " - "stopping point. Increase the max_tokens parameter to allow for longer completions." 
- ) - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_gpt_generator_run(self): - component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1) - results = component.run(prompt="What's the capital of France?") - assert len(results["replies"]) == 1 - assert "Paris" in results["replies"][0] - assert len(results["metadata"]) == 1 - assert "gpt-3.5" in results["metadata"][0]["model"] - assert results["metadata"][0]["finish_reason"] == "stop" - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_gpt_generator_run_wrong_model_name(self): - component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1) - with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"): - component.run(prompt="What's the capital of France?") - - @pytest.mark.skipif( - not os.environ.get("OPENAI_API_KEY", None), - reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", - ) - @pytest.mark.integration - def test_gpt_generator_run_streaming(self): - class Callback: - def __init__(self): - self.responses = "" - - def __call__(self, chunk): - self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else "" - return chunk - - callback = Callback() - component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1) - results = component.run(prompt="What's the capital of France?") - - assert len(results["replies"]) == 1 - assert "Paris" in results["replies"][0] - - assert len(results["metadata"]) == 1 - assert "gpt-3.5" in results["metadata"][0]["model"] - assert results["metadata"][0]["finish_reason"] == "stop" - - assert callback.responses == results["replies"][0] diff --git a/test/preview/components/generators/test_openai.py b/test/preview/components/generators/test_openai.py new file mode 100644 index 000000000..7654bd1af --- /dev/null +++ b/test/preview/components/generators/test_openai.py @@ -0,0 +1,348 @@ +import os +from typing import List +from unittest.mock import patch, Mock + +import openai +import pytest + +from haystack.preview.components.generators import GPTGenerator +from haystack.preview.components.generators.utils import default_streaming_callback +from haystack.preview.dataclasses import StreamingChunk, ChatMessage + + +@pytest.fixture +def mock_chat_completion(): + """ + Mock the OpenAI API completion response and reuse it for tests + """ + with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create: + # mimic the response from the OpenAI API + mock_choice = Mock() + mock_choice.index = 0 + mock_choice.finish_reason = "stop" + + mock_message = Mock() + mock_message.content = "I'm fine, thanks. How are you?" 
+ mock_message.role = "user" + + mock_choice.message = mock_message + + mock_response = Mock() + mock_response.model = "gpt-3.5-turbo" + mock_response.usage = Mock() + mock_response.usage.items.return_value = [ + ("prompt_tokens", 57), + ("completion_tokens", 40), + ("total_tokens", 97), + ] + mock_response.choices = [mock_choice] + mock_chat_completion_create.return_value = mock_response + yield mock_chat_completion_create + + +def streaming_chunk(content: str): + """ + Mock chunks of streaming responses from the OpenAI API + """ + # mimic the chunk response from the OpenAI API + mock_choice = Mock() + mock_choice.index = 0 + mock_choice.delta.content = content + mock_choice.finish_reason = "stop" + + mock_response = Mock() + mock_response.choices = [mock_choice] + mock_response.model = "gpt-3.5-turbo" + mock_response.usage = Mock() + mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)] + return mock_response + + +class TestGPTGenerator: + @pytest.mark.unit + def test_init_default(self): + component = GPTGenerator(api_key="test-api-key") + assert openai.api_key == "test-api-key" + assert component.model_name == "gpt-3.5-turbo" + assert component.streaming_callback is None + assert component.api_base_url == "https://api.openai.com/v1" + assert openai.api_base == "https://api.openai.com/v1" + assert not component.generation_kwargs + + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + openai.api_key = None + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): + GPTGenerator() + + @pytest.mark.unit + def test_init_with_parameters(self): + component = GPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + max_tokens=10, + some_test_param="test-params", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + ) + assert openai.api_key == "test-api-key" + assert component.model_name == "gpt-4" + assert component.streaming_callback is default_streaming_callback + assert component.api_base_url == "test-base-url" + assert openai.api_base == "test-base-url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_to_dict_default(self): + component = GPTGenerator(api_key="test-api-key") + data = component.to_dict() + assert data == { + "type": "GPTGenerator", + "init_parameters": { + "model_name": "gpt-3.5-turbo", + "streaming_callback": None, + "system_prompt": None, + "api_base_url": "https://api.openai.com/v1", + }, + } + + @pytest.mark.unit + def test_to_dict_with_parameters(self): + component = GPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + max_tokens=10, + some_test_param="test-params", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "GPTGenerator", + "init_parameters": { + "model_name": "gpt-4", + "max_tokens": 10, + "some_test_param": "test-params", + "system_prompt": None, + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", + }, + } + + @pytest.mark.unit + def test_to_dict_with_lambda_streaming_callback(self): + component = GPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + max_tokens=10, + some_test_param="test-params", + streaming_callback=lambda x: x, + api_base_url="test-base-url", + ) + data = component.to_dict() + 
assert data == { + "type": "GPTGenerator", + "init_parameters": { + "model_name": "gpt-4", + "max_tokens": 10, + "some_test_param": "test-params", + "system_prompt": None, + "api_base_url": "test-base-url", + "streaming_callback": "test_openai.", + }, + } + + @pytest.mark.unit + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") + data = { + "type": "GPTGenerator", + "init_parameters": { + "model_name": "gpt-4", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "system_prompt": None, + "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", + }, + } + component = GPTGenerator.from_dict(data) + assert component.model_name == "gpt-4" + assert component.streaming_callback is default_streaming_callback + assert component.api_base_url == "test-base-url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_from_dict_fail_wo_env_var(self, monkeypatch): + openai.api_key = None + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + data = { + "type": "GPTGenerator", + "init_parameters": { + "model_name": "gpt-4", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback", + }, + } + with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"): + GPTGenerator.from_dict(data) + + @pytest.mark.unit + def test_run(self, mock_chat_completion): + component = GPTGenerator(api_key="test-api-key") + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + def test_run_with_params(self, mock_chat_completion): + component = GPTGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5) + response = component.run("What's Natural Language Processing?") + + # check that the component calls the OpenAI API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + @pytest.mark.unit + def test_run_streaming(self, mock_chat_completion): + streaming_call_count = 0 + + # Define the streaming callback function and assert that it is called with StreamingChunk objects + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + generator = GPTGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn) + + # Create a fake streamed response + # self needed here, don't remove + def mock_iter(self): + yield streaming_chunk("Hello") + yield streaming_chunk("How are you?") + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_chat_completion.return_value = mock_response + + response = generator.run("Hello there") + + # Assert that the streaming callback was called twice + assert streaming_call_count 
== 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, str) for reply in response["replies"]] + + @pytest.mark.unit + def test_check_abnormal_completions(self, caplog): + component = GPTGenerator(api_key="test-api-key") + + # underlying implementation uses ChatMessage objects so we have to use them here + messages: List[ChatMessage] = [] + for i, _ in enumerate(range(4)): + message = ChatMessage.from_assistant("Hello") + metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} + message.metadata.update(metadata) + messages.append(message) + + for m in messages: + component._check_finish_reason(m) + + # check truncation warning + message_template = ( + "The completion for index {index} has been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions." + ) + + for index in [1, 3]: + assert caplog.records[index].message == message_template.format(index=index) + + # check content filter warning + message_template = "The completion for index {index} has been truncated due to the content filter." + for index in [0, 2]: + assert caplog.records[index].message == message_template.format(index=index) + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")) + results = component.run("What's the capital of France?") + assert len(results["replies"]) == 1 + assert len(results["metadata"]) == 1 + response: str = results["replies"][0] + assert "Paris" in response + + metadata = results["metadata"][0] + assert "gpt-3.5" in metadata["model"] + assert metadata["finish_reason"] == "stop" + + assert "usage" in metadata + assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0 + assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0 + assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0 + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self): + component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) + with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"): + component.run("Whatever") + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 + + def __call__(self, chunk: StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + + callback = Callback() + component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback) + results = component.run("What's the capital of France?") + + assert len(results["replies"]) == 1 + assert len(results["metadata"]) == 1 + response: str = results["replies"][0] + assert "Paris" in 
response + + metadata = results["metadata"][0] + + assert "gpt-3.5" in metadata["model"] + assert metadata["finish_reason"] == "stop" + + # unfortunately, the usage is not available for streaming calls + # we keep the key in the metadata for compatibility + assert "usage" in metadata and len(metadata["usage"]) == 0 + + assert callback.counter > 1 + assert "Paris" in callback.responses diff --git a/test/preview/components/generators/test_utils.py b/test/preview/components/generators/test_utils.py index 148875700..933950299 100644 --- a/test/preview/components/generators/test_utils.py +++ b/test/preview/components/generators/test_utils.py @@ -1,6 +1,6 @@ import pytest -from haystack.preview.components.generators.openai.gpt import default_streaming_callback +from haystack.preview.components.generators.utils import default_streaming_callback from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler @@ -18,7 +18,7 @@ def test_callback_handler_serialization(): @pytest.mark.unit def test_callback_handler_serialization_non_local(): result = serialize_callback_handler(default_streaming_callback) - assert result == "haystack.preview.components.generators.openai.gpt.default_streaming_callback" + assert result == "haystack.preview.components.generators.utils.default_streaming_callback" @pytest.mark.unit
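Taken together, the rewritten component keeps a string-in / string-out contract: `run` accepts a plain prompt, wraps it (plus the optional system prompt) into `ChatMessage` objects internally, and returns lists of reply strings and metadata dicts. A usage sketch based on the code in this diff; the model name and sampling parameters are illustrative:

```python
import os

from haystack.preview.components.generators import GPTGenerator

client = GPTGenerator(
    api_key=os.environ["OPENAI_API_KEY"],
    model_name="gpt-3.5-turbo",
    system_prompt="Answer in one short sentence.",
    temperature=0.7,  # any extra kwargs are kept in generation_kwargs
)

# Per-call generation_kwargs override the values given to __init__.
result = client.run(
    "What's the capital of France?",
    generation_kwargs={"max_tokens": 32},
)

print(result["replies"][0])   # one reply string per returned choice
print(result["metadata"][0])  # model, index, finish_reason, usage
```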