Mirror of https://github.com/deepset-ai/haystack.git
feat: Adapt GPTGenerator to use str input/output format in Haystack 2.x (#6214)
* Adapt GPTGenerator to string input/output
* Finishing touches
* punctuation upd
* PR feedback
* Small naming fixes
* Update haystack/preview/components/generators/openai.py
  Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
* Update class pydoc with a printed response

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent 6c5bfe3da4
commit 5497ca2a45
@@ -7,7 +7,7 @@ from haystack.preview.document_stores import InMemoryDocumentStore
 from haystack.preview.components.writers import DocumentWriter
 from haystack.preview.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
 from haystack.preview.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
-from haystack.preview.components.generators.openai.gpt import GPTGenerator
+from haystack.preview.components.generators import GPTGenerator
 from haystack.preview.components.builders.answer_builder import AnswerBuilder
 from haystack.preview.components.builders.prompt_builder import PromptBuilder

@@ -1,5 +1,5 @@
-from haystack.preview.components.generators.openai.gpt import GPTGenerator
 from haystack.preview.components.generators.hugging_face.hugging_face_local import HuggingFaceLocalGenerator
 from haystack.preview.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
+from haystack.preview.components.generators.openai import GPTGenerator

-__all__ = ["GPTGenerator", "HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator"]
+__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
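With this change, all three generators are re-exported from the package root, so user code imports from one place instead of the deeper module path. A minimal sketch of the new import style:

```python
# The generators package now re-exports GPTGenerator alongside the
# Hugging Face generators:
from haystack.preview.components.generators import (
    GPTGenerator,
    HuggingFaceLocalGenerator,
    HuggingFaceTGIGenerator,
)
```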
290  haystack/preview/components/generators/openai.py  Normal file
@@ -0,0 +1,290 @@
import dataclasses
import logging
import os
from typing import Optional, List, Callable, Dict, Any

import openai
from openai.openai_object import OpenAIObject

from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.preview.dataclasses import StreamingChunk, ChatMessage

logger = logging.getLogger(__name__)


API_BASE_URL = "https://api.openai.com/v1"


@component
class GPTGenerator:
    """
    Enables text generation using OpenAI's large language models (LLMs). It supports the gpt-4 and gpt-3.5-turbo
    model families.

    Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in the `run` method.

    For more details on the parameters supported by the OpenAI API, refer to the OpenAI
    [documentation](https://platform.openai.com/docs/api-reference/chat).

    ```python
    from haystack.preview.components.generators import GPTGenerator
    client = GPTGenerator()
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)

    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'metadata': [{'model':
    >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with the gpt-4 and gpt-3.5-turbo families of models.
    - **Streaming Support**: Supports streaming responses from the OpenAI API.
    - **Customizability**: Supports all parameters supported by the OpenAI API.

    Input and Output Format:
    - **String Format**: This component uses strings for both input and output.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        api_base_url: str = API_BASE_URL,
        system_prompt: Optional[str] = None,
        **generation_kwargs,
    ):
        """
        Creates an instance of GPTGenerator. Unless specified otherwise in `model_name`, this is for OpenAI's
        GPT-3.5 model.

        :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
            environment variable OPENAI_API_KEY (recommended).
        :param model_name: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is
            omitted, and the default system prompt of the model is used.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
              Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
              considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens
              comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
              it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty to apply if a token is already present in the text. Bigger values mean
              the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: The penalty to apply if a token has already been generated in the text.
              Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
              values are the bias to add to that token.
        """
        # if the user does not provide the API key, check if it is set in the module client
        api_key = api_key or openai.api_key
        if api_key is None:
            try:
                api_key = os.environ["OPENAI_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "GPTGenerator expects an OpenAI API key. "
                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
                ) from e
        openai.api_key = api_key

        self.model_name = model_name
        self.generation_kwargs = generation_kwargs
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback

        self.api_base_url = api_base_url
        openai.api_base = api_base_url

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model_name=self.model_name,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            **self.generation_kwargs,
            system_prompt=self.system_prompt,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
        """
        Deserialize this component from a dictionary.
        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference based on the provided prompt and generation parameters.

        :param prompt: The string prompt to use for text generation.
        :param generation_kwargs: Additional keyword arguments for text generation. These parameters
            override the parameters passed in the `__init__` method.
            For more details on the parameters supported by the OpenAI API, refer to the
            OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
        :return: A list of strings containing the generated responses and a list of dictionaries containing
            the metadata for each response.
        """
        message = ChatMessage.from_user(prompt)
        if self.system_prompt:
            messages = [ChatMessage.from_system(self.system_prompt), message]
        else:
            messages = [message]

        # update generation kwargs by merging with the generation kwargs passed to the run method
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # adapt ChatMessage(s) to the format expected by the OpenAI API
        openai_formatted_messages = self._convert_to_openai_format(messages)

        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=openai_formatted_messages,
            stream=self.streaming_callback is not None,
            **generation_kwargs,
        )

        completions: List[ChatMessage]
        if self.streaming_callback:
            num_responses = generation_kwargs.pop("n", 1)
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")
            chunks: List[StreamingChunk] = []
            chunk = None
            for chunk in completion:
                if chunk.choices:
                    chunk_delta: StreamingChunk = self._build_chunk(chunk, chunk.choices[0])
                    chunks.append(chunk_delta)
                    self.streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
            completions = [self._connect_chunks(chunk, chunks)]
        else:
            completions = [self._build_message(completion, choice) for choice in completion.choices]

        # before returning, do post-processing of the completions
        for completion in completions:
            self._check_finish_reason(completion)

        return {
            "replies": [message.content for message in completions],
            "metadata": [message.metadata for message in completions],
        }

    def _convert_to_openai_format(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Converts the list of ChatMessage to the list of messages in the format expected by the OpenAI API.
        :param messages: The list of ChatMessage.
        :return: The list of messages in the format expected by the OpenAI API.
        """
        openai_chat_message_format = {"role", "content", "name"}
        openai_formatted_messages = []
        for m in messages:
            message_dict = dataclasses.asdict(m)
            filtered_message = {k: v for k, v in message_dict.items() if k in openai_chat_message_format and v}
            openai_formatted_messages.append(filtered_message)
        return openai_formatted_messages

    def _connect_chunks(self, chunk: OpenAIObject, chunks: List[StreamingChunk]) -> ChatMessage:
        """
        Connects the streaming chunks into a single ChatMessage.
        """
        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
        complete_response.metadata.update(
            {
                "model": chunk.model,
                "index": 0,
                "finish_reason": chunk.choices[0].finish_reason,
                "usage": {},  # we don't have usage data for streaming responses
            }
        )
        return complete_response

    def _build_message(self, completion: OpenAIObject, choice: OpenAIObject) -> ChatMessage:
        """
        Converts the response from the OpenAI API to a ChatMessage.
        :param completion: The completion returned by the OpenAI API.
        :param choice: The choice returned by the OpenAI API.
        :return: The ChatMessage.
        """
        message: OpenAIObject = choice.message
        content = dict(message.function_call) if choice.finish_reason == "function_call" else message.content
        chat_message = ChatMessage.from_assistant(content)
        chat_message.metadata.update(
            {
                "model": completion.model,
                "index": choice.index,
                "finish_reason": choice.finish_reason,
                "usage": dict(completion.usage.items()),
            }
        )
        return chat_message

    def _build_chunk(self, chunk: OpenAIObject, choice: OpenAIObject) -> StreamingChunk:
        """
        Converts the response from the OpenAI API to a StreamingChunk.
        :param chunk: The chunk returned by the OpenAI API.
        :param choice: The choice returned by the OpenAI API.
        :return: The StreamingChunk.
        """
        has_content = bool(hasattr(choice.delta, "content") and choice.delta.content)
        if has_content:
            content = choice.delta.content
        elif hasattr(choice.delta, "function_call"):
            content = str(choice.delta.function_call)
        else:
            content = ""
        chunk_message = StreamingChunk(content)
        chunk_message.metadata.update(
            {"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason}
        )
        return chunk_message

    def _check_finish_reason(self, message: ChatMessage) -> None:
        """
        Check the `finish_reason` returned with the OpenAI completions.
        If the `finish_reason` is `length`, log a warning to the user.
        :param message: The message returned by the LLM.
        """
        if message.metadata["finish_reason"] == "length":
            logger.warning(
                "The completion for index %s has been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
                message.metadata["index"],
            )
        if message.metadata["finish_reason"] == "content_filter":
            logger.warning(
                "The completion for index %s has been truncated due to the content filter.", message.metadata["index"]
            )
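To make the kwargs-merging behavior of `run` concrete, here is a minimal usage sketch; it assumes a valid `OPENAI_API_KEY` is exported, and the parameter values are purely illustrative:

```python
from haystack.preview.components.generators import GPTGenerator

# Constructor kwargs become the defaults stored in `generation_kwargs`.
client = GPTGenerator(max_tokens=64, temperature=0.2)

# Per-call kwargs are merged on top of the constructor defaults inside `run`,
# so this request is sent with max_tokens=64 and temperature=0.9.
response = client.run(
    "What's Natural Language Processing? Be brief.",
    generation_kwargs={"temperature": 0.9},
)
print(response["replies"][0])
print(response["metadata"][0]["finish_reason"])  # e.g. "stop"
```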
@@ -1,225 +0,0 @@
from typing import Optional, List, Callable, Dict, Any

import sys
import logging
from collections import defaultdict
from dataclasses import dataclass, asdict
import os

import openai

from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError


logger = logging.getLogger(__name__)


API_BASE_URL = "https://api.openai.com/v1"


@dataclass
class _ChatMessage:
    content: str
    role: str


def default_streaming_callback(chunk):
    """
    Default callback function for streaming responses from OpenAI API.
    Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
    """
    if hasattr(chunk.choices[0].delta, "content"):
        print(chunk.choices[0].delta.content, flush=True, end="")
    return chunk


@component
class GPTGenerator:
    """
    LLM Generator compatible with GPT (ChatGPT) large language models.

    Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package).
    See [OpenAI GPT API](https://platform.openai.com/docs/guides/chat) for more details.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = API_BASE_URL,
        **kwargs,
    ):
        """
        Creates an instance of GPTGenerator. Unless specified otherwise in `model_name`, this is for OpenAI's GPT-3.5 model.

        :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
            environment variable OPENAI_API_KEY (recommended).
        :param model_name: The name of the model to use.
        :param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
            Typically, a conversation is formatted with a system message first, followed by alternating messages from
            the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior
            of the assistant. For example, you can modify the personality of the assistant or provide specific
            instructions about how it should behave throughout the conversation.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
            endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
              Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
              considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
              comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
              it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
              the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
              Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
              values are the bias to add to that token.
        """
        # if the user does not provide the API key, check if it is set in the module client
        api_key = api_key or openai.api_key
        if api_key is None:
            try:
                api_key = os.environ["OPENAI_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "GPTGenerator expects an OpenAI API key. "
                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
                ) from e
        openai.api_key = api_key

        self.model_name = model_name
        self.system_prompt = system_prompt
        self.model_parameters = kwargs
        self.streaming_callback = streaming_callback

        self.api_base_url = api_base_url
        openai.api_base = api_base_url

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        if self.streaming_callback:
            module = self.streaming_callback.__module__
            if module == "builtins":
                callback_name = self.streaming_callback.__name__
            else:
                callback_name = f"{module}.{self.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            model_name=self.model_name,
            system_prompt=self.system_prompt,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            **self.model_parameters,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if "streaming_callback" in init_params and init_params["streaming_callback"]:
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[str], metadata=List[Dict[str, Any]])
    def run(self, prompt: str):
        """
        Queries the LLM with the prompts to produce replies.

        :param prompts: The prompts to be sent to the generative model.
        """
        message = _ChatMessage(content=prompt, role="user")
        if self.system_prompt:
            chat = [_ChatMessage(content=self.system_prompt, role="system"), message]
        else:
            chat = [message]

        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[asdict(message) for message in chat],
            stream=self.streaming_callback is not None,
            **self.model_parameters,
        )

        replies: List[str]
        metadata: List[Dict[str, Any]]
        if self.streaming_callback:
            replies_dict: Dict[str, str] = defaultdict(str)
            metadata_dict: Dict[str, Dict[str, Any]] = defaultdict(dict)
            for chunk in completion:
                chunk = self.streaming_callback(chunk)
                for choice in chunk.choices:
                    if hasattr(choice.delta, "content"):
                        replies_dict[choice.index] += choice.delta.content
                    metadata_dict[choice.index] = {
                        "model": chunk.model,
                        "index": choice.index,
                        "finish_reason": choice.finish_reason,
                    }
            replies = list(replies_dict.values())
            metadata = list(metadata_dict.values())
            self._check_truncated_answers(metadata)
            return {"replies": replies, "metadata": metadata}

        metadata = [
            {
                "model": completion.model,
                "index": choice.index,
                "finish_reason": choice.finish_reason,
                "usage": dict(completion.usage.items()),
            }
            for choice in completion.choices
        ]
        replies = [choice.message.content.strip() for choice in completion.choices]
        self._check_truncated_answers(metadata)
        return {"replies": replies, "metadata": metadata}

    def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
        """
        Check the `finish_reason` returned with the OpenAI completions.
        If the `finish_reason` is `length`, log a warning to the user.

        :param result: The result returned from the OpenAI API.
        :param payload: The payload sent to the OpenAI API.
        """
        truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop")
        if truncated_completions > 0:
            logger.warning(
                "%s out of the %s completions have been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
                truncated_completions,
                len(metadata),
            )
@@ -6,6 +6,14 @@ from haystack.preview import DeserializationError
 from haystack.preview.dataclasses import StreamingChunk


+def default_streaming_callback(chunk: StreamingChunk) -> None:
+    """
+    Default callback function for streaming responses.
+    Prints the tokens of the first completion to stdout as soon as they are received.
+    """
+    print(chunk.content, flush=True, end="")
+
+
 def serialize_callback_handler(streaming_callback: Callable[[StreamingChunk], None]) -> str:
     """
     Serializes the streaming callback handler.
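A short sketch of wiring the new `StreamingChunk`-based callback into the generator (again assuming `OPENAI_API_KEY` is set in the environment):

```python
from haystack.preview.components.generators import GPTGenerator
from haystack.preview.components.generators.utils import default_streaming_callback

# Each StreamingChunk is handed to the callback as it arrives; the default
# callback above just prints chunk.content to stdout.
client = GPTGenerator(streaming_callback=default_streaming_callback)
result = client.run("What's Natural Language Processing? Be brief.")

# Usage data is not available for streaming responses, so the "usage"
# entry in the metadata stays an empty dict.
print(result["metadata"][0]["usage"])  # {}
```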
@@ -0,0 +1,4 @@
---
preview:
  - |
    Adapt GPTGenerator to use strings for input and output
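The release note captures the user-visible change: `run` now takes a plain string and returns lists of strings, with the `ChatMessage` plumbing kept internal. A minimal sketch, assuming a valid API key in the environment:

```python
from haystack.preview.components.generators import GPTGenerator

# The optional system prompt is set once at construction time; the call
# itself is just a string in, strings out.
client = GPTGenerator(system_prompt="Answer in exactly one sentence.")
result = client.run("What's the capital of France?")

assert isinstance(result["replies"][0], str)
print(result["replies"][0])
```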
@@ -1,334 +0,0 @@
import os
from unittest.mock import patch, Mock
from copy import deepcopy

import pytest
import openai
from openai.util import convert_to_openai_object

from haystack.preview.components.generators.openai.gpt import GPTGenerator
from haystack.preview.components.generators.openai.gpt import default_streaming_callback


def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
    response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
    base_dict = {
        "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
        "object": "chat.completion",
        "created": 1685855844,
        "model": model,
        "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
    }
    base_dict["choices"] = [
        {"message": {"role": "assistant", "content": response}, "finish_reason": "stop", "index": "0"}
    ]
    return convert_to_openai_object(deepcopy(base_dict))


def mock_openai_stream_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
    response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
    base_dict = {
        "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
        "object": "chat.completion",
        "created": 1685855844,
        "model": model,
    }
    base_dict["choices"] = [{"delta": {"role": "assistant"}, "finish_reason": None, "index": "0"}]
    yield convert_to_openai_object(base_dict)
    for token in response.split():
        base_dict["choices"][0]["delta"] = {"content": token + " "}
        yield convert_to_openai_object(base_dict)
    base_dict["choices"] = [{"delta": {"content": ""}, "finish_reason": "stop", "index": "0"}]
    yield convert_to_openai_object(base_dict)


class TestGPTGenerator:
    @pytest.mark.unit
    def test_init_default(self):
        component = GPTGenerator(api_key="test-api-key")
        assert openai.api_key == "test-api-key"
        assert component.system_prompt is None
        assert component.model_name == "gpt-3.5-turbo"
        assert component.streaming_callback is None
        assert component.api_base_url == "https://api.openai.com/v1"
        assert openai.api_base == "https://api.openai.com/v1"
        assert component.model_parameters == {}

    @pytest.mark.unit
    def test_init_fail_wo_api_key(self, monkeypatch):
        openai.api_key = None
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
            GPTGenerator()

    @pytest.mark.unit
    def test_init_with_parameters(self):
        callback = lambda x: x
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        assert openai.api_key == "test-api-key"
        assert component.system_prompt == "test-system-prompt"
        assert component.model_name == "gpt-4"
        assert component.streaming_callback == callback
        assert component.api_base_url == "test-base-url"
        assert openai.api_base == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_to_dict_default(self):
        component = GPTGenerator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-3.5-turbo",
                "system_prompt": None,
                "streaming_callback": None,
                "api_base_url": "https://api.openai.com/v1",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_lambda_streaming_callback(self):
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "test_gpt_generator.<lambda>",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
            },
        }
        component = GPTGenerator.from_dict(data)
        assert component.system_prompt == "test-system-prompt"
        assert component.model_name == "gpt-4"
        assert component.streaming_callback == default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_from_dict_fail_wo_env_var(self, monkeypatch):
        openai.api_key = None
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        data = {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
            },
        }
        with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
            GPTGenerator.from_dict(data)

    @pytest.mark.unit
    def test_run_no_system_prompt(self):
        with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
            gpt_patch.create.side_effect = mock_openai_response
            component = GPTGenerator(api_key="test-api-key")
            results = component.run(prompt="test-prompt-1")
            assert results == {
                "replies": ["response for these messages --> user: test-prompt-1"],
                "metadata": [
                    {
                        "model": "gpt-3.5-turbo",
                        "index": "0",
                        "finish_reason": "stop",
                        "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
                    }
                ],
            }
            gpt_patch.create.assert_called_once_with(
                model="gpt-3.5-turbo", messages=[{"role": "user", "content": "test-prompt-1"}], stream=False
            )

    @pytest.mark.unit
    def test_run_with_system_prompt(self):
        with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
            gpt_patch.create.side_effect = mock_openai_response
            component = GPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt")
            results = component.run(prompt="test-prompt-1")
            assert results == {
                "replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1"],
                "metadata": [
                    {
                        "model": "gpt-3.5-turbo",
                        "index": "0",
                        "finish_reason": "stop",
                        "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
                    }
                ],
            }
            gpt_patch.create.assert_called_once_with(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "test-system-prompt"},
                    {"role": "user", "content": "test-prompt-1"},
                ],
                stream=False,
            )

    @pytest.mark.unit
    def test_run_with_parameters(self):
        with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
            gpt_patch.create.side_effect = mock_openai_response
            component = GPTGenerator(api_key="test-api-key", max_tokens=10)
            component.run(prompt="test-prompt-1")
            gpt_patch.create.assert_called_once_with(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "test-prompt-1"}],
                stream=False,
                max_tokens=10,
            )

    @pytest.mark.unit
    def test_run_stream(self):
        with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
            mock_callback = Mock()
            mock_callback.side_effect = default_streaming_callback
            gpt_patch.create.side_effect = mock_openai_stream_response
            component = GPTGenerator(
                api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback
            )
            results = component.run(prompt="test-prompt-1")
            assert results == {
                "replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1 "],
                "metadata": [{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}],
            }
            # Calls count: 10 tokens per prompt + 1 token for the role + 1 empty termination token
            assert mock_callback.call_count == 12
            gpt_patch.create.assert_called_once_with(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "test-system-prompt"},
                    {"role": "user", "content": "test-prompt-1"},
                ],
                stream=True,
            )

    @pytest.mark.unit
    def test_check_truncated_answers(self, caplog):
        component = GPTGenerator(api_key="test-api-key")
        metadata = [
            {"finish_reason": "stop"},
            {"finish_reason": "content_filter"},
            {"finish_reason": "length"},
            {"finish_reason": "stop"},
        ]
        component._check_truncated_answers(metadata)
        assert caplog.records[0].message == (
            "2 out of the 4 completions have been truncated before reaching a natural "
            "stopping point. Increase the max_tokens parameter to allow for longer completions."
        )

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_gpt_generator_run(self):
        component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
        results = component.run(prompt="What's the capital of France?")
        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]
        assert len(results["metadata"]) == 1
        assert "gpt-3.5" in results["metadata"][0]["model"]
        assert results["metadata"][0]["finish_reason"] == "stop"

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_gpt_generator_run_wrong_model_name(self):
        component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
        with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
            component.run(prompt="What's the capital of France?")

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_gpt_generator_run_streaming(self):
        class Callback:
            def __init__(self):
                self.responses = ""

            def __call__(self, chunk):
                self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else ""
                return chunk

        callback = Callback()
        component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
        results = component.run(prompt="What's the capital of France?")

        assert len(results["replies"]) == 1
        assert "Paris" in results["replies"][0]

        assert len(results["metadata"]) == 1
        assert "gpt-3.5" in results["metadata"][0]["model"]
        assert results["metadata"][0]["finish_reason"] == "stop"

        assert callback.responses == results["replies"][0]
348  test/preview/components/generators/test_openai.py  Normal file
@@ -0,0 +1,348 @@
import os
from typing import List
from unittest.mock import patch, Mock

import openai
import pytest

from haystack.preview.components.generators import GPTGenerator
from haystack.preview.components.generators.utils import default_streaming_callback
from haystack.preview.dataclasses import StreamingChunk, ChatMessage


@pytest.fixture
def mock_chat_completion():
    """
    Mock the OpenAI API completion response and reuse it for tests
    """
    with patch("openai.ChatCompletion.create", autospec=True) as mock_chat_completion_create:
        # mimic the response from the OpenAI API
        mock_choice = Mock()
        mock_choice.index = 0
        mock_choice.finish_reason = "stop"

        mock_message = Mock()
        mock_message.content = "I'm fine, thanks. How are you?"
        mock_message.role = "user"

        mock_choice.message = mock_message

        mock_response = Mock()
        mock_response.model = "gpt-3.5-turbo"
        mock_response.usage = Mock()
        mock_response.usage.items.return_value = [
            ("prompt_tokens", 57),
            ("completion_tokens", 40),
            ("total_tokens", 97),
        ]
        mock_response.choices = [mock_choice]
        mock_chat_completion_create.return_value = mock_response
        yield mock_chat_completion_create


def streaming_chunk(content: str):
    """
    Mock chunks of streaming responses from the OpenAI API
    """
    # mimic the chunk response from the OpenAI API
    mock_choice = Mock()
    mock_choice.index = 0
    mock_choice.delta.content = content
    mock_choice.finish_reason = "stop"

    mock_response = Mock()
    mock_response.choices = [mock_choice]
    mock_response.model = "gpt-3.5-turbo"
    mock_response.usage = Mock()
    mock_response.usage.items.return_value = [("prompt_tokens", 57), ("completion_tokens", 40), ("total_tokens", 97)]
    return mock_response


class TestGPTGenerator:
    @pytest.mark.unit
    def test_init_default(self):
        component = GPTGenerator(api_key="test-api-key")
        assert openai.api_key == "test-api-key"
        assert component.model_name == "gpt-3.5-turbo"
        assert component.streaming_callback is None
        assert component.api_base_url == "https://api.openai.com/v1"
        assert openai.api_base == "https://api.openai.com/v1"
        assert not component.generation_kwargs

    @pytest.mark.unit
    def test_init_fail_wo_api_key(self, monkeypatch):
        openai.api_key = None
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
            GPTGenerator()

    @pytest.mark.unit
    def test_init_with_parameters(self):
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        assert openai.api_key == "test-api-key"
        assert component.model_name == "gpt-4"
        assert component.streaming_callback is default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert openai.api_base == "test-base-url"
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_to_dict_default(self):
        component = GPTGenerator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-3.5-turbo",
                "streaming_callback": None,
                "system_prompt": None,
                "api_base_url": "https://api.openai.com/v1",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "system_prompt": None,
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_lambda_streaming_callback(self):
        component = GPTGenerator(
            api_key="test-api-key",
            model_name="gpt-4",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=lambda x: x,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "system_prompt": None,
                "api_base_url": "test-base-url",
                "streaming_callback": "test_openai.<lambda>",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "system_prompt": None,
                "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
            },
        }
        component = GPTGenerator.from_dict(data)
        assert component.model_name == "gpt-4"
        assert component.streaming_callback is default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_from_dict_fail_wo_env_var(self, monkeypatch):
        openai.api_key = None
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        data = {
            "type": "GPTGenerator",
            "init_parameters": {
                "model_name": "gpt-4",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.utils.default_streaming_callback",
            },
        }
        with pytest.raises(ValueError, match="GPTGenerator expects an OpenAI API key"):
            GPTGenerator.from_dict(data)

    @pytest.mark.unit
    def test_run(self, mock_chat_completion):
        component = GPTGenerator(api_key="test-api-key")
        response = component.run("What's Natural Language Processing?")

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, str) for reply in response["replies"])

    @pytest.mark.unit
    def test_run_with_params(self, mock_chat_completion):
        component = GPTGenerator(api_key="test-api-key", max_tokens=10, temperature=0.5)
        response = component.run("What's Natural Language Processing?")

        # check that the component calls the OpenAI API with the correct parameters
        _, kwargs = mock_chat_completion.call_args
        assert kwargs["max_tokens"] == 10
        assert kwargs["temperature"] == 0.5

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        assert all(isinstance(reply, str) for reply in response["replies"])

    @pytest.mark.unit
    def test_run_streaming(self, mock_chat_completion):
        streaming_call_count = 0

        # Define the streaming callback function and assert that it is called with StreamingChunk objects
        def streaming_callback_fn(chunk: StreamingChunk):
            nonlocal streaming_call_count
            streaming_call_count += 1
            assert isinstance(chunk, StreamingChunk)

        generator = GPTGenerator(api_key="test-api-key", streaming_callback=streaming_callback_fn)

        # Create a fake streamed response
        # self needed here, don't remove
        def mock_iter(self):
            yield streaming_chunk("Hello")
            yield streaming_chunk("How are you?")

        mock_response = Mock(**{"__iter__": mock_iter})
        mock_chat_completion.return_value = mock_response

        response = generator.run("Hello there")

        # Assert that the streaming callback was called twice
        assert streaming_call_count == 2

        # Assert that the response contains the generated replies
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) > 0
        assert all(isinstance(reply, str) for reply in response["replies"])

    @pytest.mark.unit
    def test_check_abnormal_completions(self, caplog):
        component = GPTGenerator(api_key="test-api-key")

        # underlying implementation uses ChatMessage objects so we have to use them here
        messages: List[ChatMessage] = []
        for i, _ in enumerate(range(4)):
            message = ChatMessage.from_assistant("Hello")
            metadata = {"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
            message.metadata.update(metadata)
            messages.append(message)

        for m in messages:
            component._check_finish_reason(m)

        # check truncation warning
        message_template = (
            "The completion for index {index} has been truncated before reaching a natural stopping point. "
            "Increase the max_tokens parameter to allow for longer completions."
        )
        for index in [1, 3]:
            assert caplog.records[index].message == message_template.format(index=index)

        # check content filter warning
        message_template = "The completion for index {index} has been truncated due to the content filter."
        for index in [0, 2]:
            assert caplog.records[index].message == message_template.format(index=index)

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run(self):
        component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
        results = component.run("What's the capital of France?")
        assert len(results["replies"]) == 1
        assert len(results["metadata"]) == 1
        response: str = results["replies"][0]
        assert "Paris" in response

        metadata = results["metadata"][0]
        assert "gpt-3.5" in metadata["model"]
        assert metadata["finish_reason"] == "stop"

        assert "usage" in metadata
        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_wrong_model(self):
        component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
        with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
            component.run("Whatever")

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_streaming(self):
        class Callback:
            def __init__(self):
                self.responses = ""
                self.counter = 0

            def __call__(self, chunk: StreamingChunk) -> None:
                self.counter += 1
                self.responses += chunk.content if chunk.content else ""

        callback = Callback()
        component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
        results = component.run("What's the capital of France?")

        assert len(results["replies"]) == 1
        assert len(results["metadata"]) == 1
        response: str = results["replies"][0]
        assert "Paris" in response

        metadata = results["metadata"][0]

        assert "gpt-3.5" in metadata["model"]
        assert metadata["finish_reason"] == "stop"

        # unfortunately, the usage is not available for streaming calls
        # we keep the key in the metadata for compatibility
        assert "usage" in metadata and len(metadata["usage"]) == 0

        assert callback.counter > 1
        assert "Paris" in callback.responses
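The tests above exercise `to_dict` and `from_dict` separately; a round-trip sketch of the same serialization contract (the key passed via `api_key` persists on the `openai` module, so deserialization can find it again):

```python
from haystack.preview.components.generators import GPTGenerator

gen = GPTGenerator(api_key="test-api-key", model_name="gpt-4", max_tokens=10)
data = gen.to_dict()

# streaming_callback is stored as a dotted import path (or None), so the
# component dictionary stays plain data and is safe to serialize.
restored = GPTGenerator.from_dict(data)
assert restored.model_name == "gpt-4"
assert restored.generation_kwargs == {"max_tokens": 10}
```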
@@ -1,6 +1,6 @@
 import pytest

-from haystack.preview.components.generators.openai.gpt import default_streaming_callback
+from haystack.preview.components.generators.utils import default_streaming_callback
 from haystack.preview.components.generators.utils import serialize_callback_handler, deserialize_callback_handler


@@ -18,7 +18,7 @@ def test_callback_handler_serialization():
 @pytest.mark.unit
 def test_callback_handler_serialization_non_local():
     result = serialize_callback_handler(default_streaming_callback)
-    assert result == "haystack.preview.components.generators.openai.gpt.default_streaming_callback"
+    assert result == "haystack.preview.components.generators.utils.default_streaming_callback"


 @pytest.mark.unit