feat: GPT35Generator (#5714)

* chatgpt backend

* fix tests

* reno

* remove print

* helpers tests

* add chatgpt generator

* use openai sdk

* remove backend

* tests are broken

* fix tests

* stray param

* move _check_truncated_answers into the class

* wrong import

* rename function

* typo in test

* add openai deps

* mypy

* improve system prompt docstring

* typos update

* Update haystack/preview/components/generators/openai/chatgpt.py

* pylint

* Update haystack/preview/components/generators/openai/chatgpt.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update haystack/preview/components/generators/openai/chatgpt.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Update haystack/preview/components/generators/openai/chatgpt.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* review feedback

* fix tests

* review feedback

* reno

* remove tenacity mock

* gpt35generator

* fix naming

* remove stray references to chatgpt

* fix e2e

* Update releasenotes/notes/chatgpt-llm-generator-d043532654efe684.yaml

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* add another test

* test wrong model name

* review feedback

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
ZanSara 2023-09-07 09:06:57 +01:00 committed by GitHub
parent c5edb45c10
commit 63cbde7287
8 changed files with 635 additions and 54 deletions


@@ -0,0 +1,86 @@
import os
import pytest
import openai
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run():
component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"])
assert len(results["replies"]) == 2
assert len(results["replies"][0]) == 1
assert "Paris" in results["replies"][0][0]
assert len(results["replies"][1]) == 1
assert "Berlin" in results["replies"][1][0]
assert len(results["metadata"]) == 2
assert len(results["metadata"][0]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
assert "stop" == results["metadata"][0][0]["finish_reason"]
assert len(results["metadata"][1]) == 1
assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
assert "stop" == results["metadata"][1][0]["finish_reason"]
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run_wrong_model_name():
component = GPT35Generator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
component.run(prompts=["What's the capital of France?"])
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run_above_context_length():
component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
with pytest.raises(
openai.InvalidRequestError,
match="This model's maximum context length is 4097 tokens. However, your messages resulted in 70008 tokens. "
"Please reduce the length of the messages.",
):
component.run(prompts=["What's the capital of France? " * 10_000])
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run_streaming():
class Callback:
def __init__(self):
self.responses = ""
def __call__(self, chunk):
self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else ""
return chunk
callback = Callback()
component = GPT35Generator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"])
assert len(results["replies"]) == 2
assert len(results["replies"][0]) == 1
assert "Paris" in results["replies"][0][0]
assert len(results["replies"][1]) == 1
assert "Berlin" in results["replies"][1][0]
assert callback.responses == results["replies"][0][0] + results["replies"][1][0]
assert len(results["metadata"]) == 2
assert len(results["metadata"][0]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
assert "stop" == results["metadata"][0][0]["finish_reason"]
assert len(results["metadata"][1]) == 1
assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
assert "stop" == results["metadata"][1][0]["finish_reason"]


@@ -1,33 +0,0 @@
import logging
from haystack.preview.lazy_imports import LazyImport
with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
import tiktoken
logger = logging.getLogger(__name__)
def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
"""
Ensure that the length of the prompt is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
:param tokenizer: The tokenizer used to encode the prompt.
:param max_tokens_limit: The max tokens limit of the model.
:return: The prompt text that fits within the max tokens limit of the model.
"""
tiktoken_import.check()
tokens = tokenizer.encode(prompt)
tokens_count = len(tokens)
if tokens_count > max_tokens_limit:
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off.",
tokens_count,
max_tokens_limit,
)
prompt = tokenizer.decode(tokens[:max_tokens_limit])
return prompt
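
For context on the helper removed above, a minimal sketch of how it used to be called; the encoding name below is an illustrative assumption, and the import only resolves on versions that still ship the module:

# Editorial sketch only: typical usage of the enforce_token_limit helper removed above.
# Note: this commit deletes the _helpers module, so the import works only on earlier versions.
import tiktoken

from haystack.preview.components.generators.openai._helpers import enforce_token_limit

tokenizer = tiktoken.get_encoding("cl100k_base")  # illustrative choice of encoding
long_prompt = "What's the capital of France? " * 2_000
truncated = enforce_token_limit(long_prompt, tokenizer=tokenizer, max_tokens_limit=4_000)
# `truncated` decodes to at most 4_000 tokens; a warning is logged when the prompt is cut.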


@@ -0,0 +1,213 @@
from typing import Optional, List, Callable, Dict, Any
import sys
import logging
from dataclasses import dataclass, asdict
import openai
from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
logger = logging.getLogger(__name__)
@dataclass
class _ChatMessage:
content: str
role: str
def default_streaming_callback(chunk):
"""
Default callback function for streaming responses from OpenAI API.
Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
"""
if hasattr(chunk.choices[0].delta, "content"):
print(chunk.choices[0].delta.content, flush=True, end="")
return chunk
@component
class GPT35Generator:
"""
LLM Generator compatible with GPT3.5 (ChatGPT) large language models.
Queries the LLM using OpenAI's API. Invocations are made using the OpenAI SDK ('openai' package).
See [OpenAI GPT3.5 API](https://platform.openai.com/docs/guides/chat) for more details.
"""
def __init__(
self,
api_key: str,
model_name: str = "gpt-3.5-turbo",
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable] = None,
api_base_url: str = "https://api.openai.com/v1",
**kwargs,
):
"""
Creates an instance of GPT35Generator for OpenAI's GPT-3.5 model.
:param api_key: The OpenAI API key.
:param model_name: The name of the model to use.
:param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
Typically, a conversation is formatted with a system message first, followed by alternating messages from
the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior
of the assistant. For example, you can modify the personality of the assistant or provide specific
instructions about how it should behave throughout the conversation.
:param streaming_callback: A callback function that is called on every chunk received from the stream.
The callback accepts a single parameter, the streamed chunk, and must return it (possibly modified)
so processing can continue. If no callback is provided, streaming is disabled and complete replies
are returned in a single response.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
Some of the supported parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
- `stop`: One or more sequences after which the LLM should stop generating tokens.
- `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
- `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
"""
self.api_key = api_key
self.model_name = model_name
self.system_prompt = system_prompt
self.model_parameters = kwargs
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
if self.streaming_callback:
module = self.streaming_callback.__module__
if module == "builtins":
callback_name = self.streaming_callback.__name__
else:
callback_name = f"{module}.{self.streaming_callback.__name__}"
else:
callback_name = None
return default_to_dict(
self,
api_key=self.api_key,
model_name=self.model_name,
system_prompt=self.system_prompt,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPT35Generator":
"""
Deserialize this component from a dictionary.
"""
init_params = data.get("init_parameters", {})
streaming_callback = None
if "streaming_callback" in init_params:
parts = init_params["streaming_callback"].split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
streaming_callback = getattr(module, function_name, None)
if not streaming_callback:
raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
data["init_parameters"]["streaming_callback"] = streaming_callback
return default_from_dict(cls, data)
@component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
def run(self, prompts: List[str]):
"""
Queries the LLM with the prompts to produce replies.
:param prompts: The prompts to be sent to the generative model.
"""
chats = []
for prompt in prompts:
message = _ChatMessage(content=prompt, role="user")
if self.system_prompt:
chats.append([_ChatMessage(content=self.system_prompt, role="system"), message])
else:
chats.append([message])
all_replies, all_metadata = [], []
for chat in chats:
completion = openai.ChatCompletion.create(
model=self.model_name,
api_key=self.api_key,
messages=[asdict(message) for message in chat],
stream=self.streaming_callback is not None,
**self.model_parameters,
)
replies: List[str]
metadata: List[Dict[str, Any]]
if self.streaming_callback:
replies_dict = {}
metadata_dict: Dict[str, Dict[str, Any]] = {}
for chunk in completion:
chunk = self.streaming_callback(chunk)
for choice in chunk.choices:
if choice.index not in replies_dict:
replies_dict[choice.index] = ""
metadata_dict[choice.index] = {}
if hasattr(choice.delta, "content"):
replies_dict[choice.index] += choice.delta.content
metadata_dict[choice.index] = {
"model": chunk.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
}
all_replies.append(list(replies_dict.values()))
all_metadata.append(list(metadata_dict.values()))
self._check_truncated_answers(list(metadata_dict.values()))
else:
metadata = [
{
"model": completion.model,
"index": choice.index,
"finish_reason": choice.finish_reason,
"usage": dict(completion.usage.items()),
}
for choice in completion.choices
]
replies = [choice.message.content.strip() for choice in completion.choices]
all_replies.append(replies)
all_metadata.append(metadata)
self._check_truncated_answers(metadata)
return {"replies": all_replies, "metadata": all_metadata}
def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
"""
Check the `finish_reason` returned with the OpenAI completions.
If any completion ends with a `finish_reason` other than `stop`, log a warning to the user.
:param metadata: The metadata returned by the OpenAI API for each completion.
"""
truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop")
if truncated_completions > 0:
logger.warning(
"%s out of the %s completions have been truncated before reaching a natural stopping point. "
"Increase the max_tokens parameter to allow for longer completions.",
truncated_completions,
len(metadata),
)
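
A minimal usage sketch of the component defined above (editorial addition, not part of this diff; assumes OPENAI_API_KEY is exported):

# Editorial sketch only: basic, non-streaming usage of GPT35Generator.
import os

from haystack.preview.components.generators.openai.gpt35 import GPT35Generator

generator = GPT35Generator(
    api_key=os.environ["OPENAI_API_KEY"],
    system_prompt="Answer in one short sentence.",
    max_tokens=64,  # forwarded verbatim to the OpenAI endpoint via **kwargs
)
result = generator.run(prompts=["What's the capital of France?"])
print(result["replies"][0][0])   # first completion for the first prompt
print(result["metadata"][0][0])  # model, index, finish_reason, usage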


@@ -80,6 +80,7 @@ dependencies = [
# Preview
"canals==0.8.0",
"openai",
"Jinja2",
"openai-whisper", # FIXME https://github.com/deepset-ai/haystack/issues/5731


@@ -0,0 +1,2 @@
preview:
- Introduce `GPT35Generator`, a class that can generate completions using OpenAI Chat models like GPT3.5 and GPT4.


@@ -0,0 +1,332 @@
from typing import List, Dict
from unittest.mock import patch, Mock
from copy import deepcopy
import pytest
import openai
from openai.util import convert_to_openai_object
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator
from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback
def mock_openai_response(messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
base_dict = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": model,
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
base_dict["choices"] = [
{"message": {"role": "assistant", "content": response}, "finish_reason": "stop", "index": "0"}
]
return convert_to_openai_object(deepcopy(base_dict))
def mock_openai_stream_response(messages: List[Dict[str, str]], model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}"
base_dict = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": model,
}
base_dict["choices"] = [{"delta": {"role": "assistant"}, "finish_reason": None, "index": "0"}]
yield convert_to_openai_object(base_dict)
for token in response.split():
base_dict["choices"][0]["delta"] = {"content": token + " "}
yield convert_to_openai_object(base_dict)
base_dict["choices"] = [{"delta": {"content": ""}, "finish_reason": "stop", "index": "0"}]
yield convert_to_openai_object(base_dict)
class TestGPT35Generator:
@pytest.mark.unit
def test_init_default(self):
component = GPT35Generator(api_key="test-api-key")
assert component.system_prompt is None
assert component.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
assert component.api_base_url == "https://api.openai.com/v1"
assert component.model_parameters == {}
@pytest.mark.unit
def test_init_with_parameters(self):
callback = lambda x: x
component = GPT35Generator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=callback,
api_base_url="test-base-url",
)
assert component.system_prompt == "test-system-prompt"
assert component.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_to_dict_default(self):
component = GPT35Generator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "GPT35Generator",
"init_parameters": {
"api_key": "test-api-key",
"model_name": "gpt-3.5-turbo",
"system_prompt": None,
"streaming_callback": None,
"api_base_url": "https://api.openai.com/v1",
},
}
@pytest.mark.unit
def test_to_dict_with_parameters(self):
component = GPT35Generator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPT35Generator",
"init_parameters": {
"api_key": "test-api-key",
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
},
}
@pytest.mark.unit
def test_to_dict_with_lambda_streaming_callback(self):
component = GPT35Generator(
api_key="test-api-key",
model_name="gpt-4",
system_prompt="test-system-prompt",
max_tokens=10,
some_test_param="test-params",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
)
data = component.to_dict()
assert data == {
"type": "GPT35Generator",
"init_parameters": {
"api_key": "test-api-key",
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "test_gpt35_generator.<lambda>",
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "GPT35Generator",
"init_parameters": {
"api_key": "test-api-key",
"model_name": "gpt-4",
"system_prompt": "test-system-prompt",
"max_tokens": 10,
"some_test_param": "test-params",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
},
}
component = GPT35Generator.from_dict(data)
assert component.system_prompt == "test-system-prompt"
assert component.api_key == "test-api-key"
assert component.model_name == "gpt-4"
assert component.streaming_callback == default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
@pytest.mark.unit
def test_run_no_system_prompt(self):
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
gpt35_patch.create.side_effect = mock_openai_response
component = GPT35Generator(api_key="test-api-key")
results = component.run(prompts=["test-prompt-1", "test-prompt-2"])
assert results == {
"replies": [
["response for these messages --> user: test-prompt-1"],
["response for these messages --> user: test-prompt-2"],
],
"metadata": [
[
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
[
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
],
}
assert gpt35_patch.create.call_count == 2
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[{"role": "user", "content": "test-prompt-1"}],
stream=False,
)
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[{"role": "user", "content": "test-prompt-2"}],
stream=False,
)
@pytest.mark.unit
def test_run_with_system_prompt(self):
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
gpt35_patch.create.side_effect = mock_openai_response
component = GPT35Generator(api_key="test-api-key", system_prompt="test-system-prompt")
results = component.run(prompts=["test-prompt-1", "test-prompt-2"])
assert results == {
"replies": [
["response for these messages --> system: test-system-prompt - user: test-prompt-1"],
["response for these messages --> system: test-system-prompt - user: test-prompt-2"],
],
"metadata": [
[
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
[
{
"model": "gpt-3.5-turbo",
"index": "0",
"finish_reason": "stop",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
}
],
],
}
assert gpt35_patch.create.call_count == 2
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-1"},
],
stream=False,
)
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-2"},
],
stream=False,
)
@pytest.mark.unit
def test_run_with_parameters(self):
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
gpt35_patch.create.side_effect = mock_openai_response
component = GPT35Generator(api_key="test-api-key", max_tokens=10)
component.run(prompts=["test-prompt-1", "test-prompt-2"])
assert gpt35_patch.create.call_count == 2
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[{"role": "user", "content": "test-prompt-1"}],
stream=False,
max_tokens=10,
)
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[{"role": "user", "content": "test-prompt-2"}],
stream=False,
max_tokens=10,
)
@pytest.mark.unit
def test_run_stream(self):
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
mock_callback = Mock()
mock_callback.side_effect = default_streaming_callback
gpt35_patch.create.side_effect = mock_openai_stream_response
component = GPT35Generator(
api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback
)
results = component.run(prompts=["test-prompt-1", "test-prompt-2"])
assert results == {
"replies": [
["response for these messages --> system: test-system-prompt - user: test-prompt-1 "],
["response for these messages --> system: test-system-prompt - user: test-prompt-2 "],
],
"metadata": [
[{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}],
[{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}],
],
}
# Calls count: (10 tokens per prompt + 1 token for the role + 1 empty termination token) * 2 prompts
assert mock_callback.call_count == 24
assert gpt35_patch.create.call_count == 2
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-1"},
],
stream=True,
)
gpt35_patch.create.assert_any_call(
model="gpt-3.5-turbo",
api_key="test-api-key",
messages=[
{"role": "system", "content": "test-system-prompt"},
{"role": "user", "content": "test-prompt-2"},
],
stream=True,
)
@pytest.mark.unit
def test_check_truncated_answers(self, caplog):
component = GPT35Generator(api_key="test-api-key")
metadata = [
{"finish_reason": "stop"},
{"finish_reason": "content_filter"},
{"finish_reason": "length"},
{"finish_reason": "stop"},
]
component._check_truncated_answers(metadata)
assert caplog.records[0].message == (
"2 out of the 4 completions have been truncated before reaching a natural "
"stopping point. Increase the max_tokens parameter to allow for longer completions."
)
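
The tests above exercise to_dict and from_dict; as an editorial illustration, a minimal round-trip sketch (the module must already be imported so from_dict can resolve the callback):

# Editorial sketch only: serialization round-trip of GPT35Generator.
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator, default_streaming_callback

generator = GPT35Generator(api_key="test-api-key", model_name="gpt-4", streaming_callback=default_streaming_callback)
data = generator.to_dict()
# The callback is stored as a dotted import path so from_dict can look it up again.
assert data["init_parameters"]["streaming_callback"].endswith("default_streaming_callback")
restored = GPT35Generator.from_dict(data)
assert restored.streaming_callback is default_streaming_callback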


@@ -1,20 +0,0 @@
import pytest
from haystack.preview.components.generators.openai._helpers import enforce_token_limit
@pytest.mark.unit
def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
assert prompt == "This is a"
assert caplog.records[0].message == (
"The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
"limit. Reduce the length of the prompt to prevent it from being cut off."
)
@pytest.mark.unit
def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
assert prompt == "This is a test prompt."
assert not caplog.records


@@ -1,4 +1,4 @@
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest