mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 16:46:58 +00:00
Remove unnecessary GPT4Generator class (#5863)
* Remove GPT4Generator class * Rename GPT35Generator to GPTGenerator * Fix tests * Release notes
This commit is contained in:
parent
f3dc9edd26
commit
cc4f95bf51
@ -33,12 +33,12 @@ def default_streaming_callback(chunk):
|
|||||||
|
|
||||||
|
|
||||||
@component
|
@component
|
||||||
class GPT35Generator:
|
class GPTGenerator:
|
||||||
"""
|
"""
|
||||||
LLM Generator compatible with GPT3.5 (ChatGPT) large language models.
|
LLM Generator compatible with GPT (ChatGPT) large language models.
|
||||||
|
|
||||||
Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package)
|
Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package)
|
||||||
See [OpenAI GPT3.5 API](https://platform.openai.com/docs/guides/chat) for more details.
|
See [OpenAI GPT API](https://platform.openai.com/docs/guides/chat) for more details.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -115,7 +115,7 @@ class GPT35Generator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "GPT35Generator":
|
def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
|
||||||
"""
|
"""
|
||||||
Deserialize this component from a dictionary.
|
Deserialize this component from a dictionary.
|
||||||
"""
|
"""
|
@ -1,71 +0,0 @@
|
|||||||
from typing import Optional, Callable
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from haystack.preview import component
|
|
||||||
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator, API_BASE_URL
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@component
|
|
||||||
class GPT4Generator(GPT35Generator):
|
|
||||||
"""
|
|
||||||
LLM Generator compatible with GPT4 large language models.
|
|
||||||
|
|
||||||
Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package)
|
|
||||||
See [OpenAI GPT4 API](https://platform.openai.com/docs/guides/chat) for more details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str,
|
|
||||||
model_name: str = "gpt-4",
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
streaming_callback: Optional[Callable] = None,
|
|
||||||
api_base_url: str = API_BASE_URL,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Creates an instance of GPT4Generator for OpenAI's GPT-4 model.
|
|
||||||
|
|
||||||
:param api_key: The OpenAI API key.
|
|
||||||
:param model_name: The name of the model to use.
|
|
||||||
:param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
|
|
||||||
Typically, a conversation is formatted with a system message first, followed by alternating messages from
|
|
||||||
the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior
|
|
||||||
of the assistant. For example, you can modify the personality of the assistant or provide specific
|
|
||||||
instructions about how it should behave throughout the conversation.
|
|
||||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
|
||||||
The callback function should accept two parameters: the token received from the stream and **kwargs.
|
|
||||||
The callback function should return the token to be sent to the stream. If the callback function is not
|
|
||||||
provided, the token is printed to stdout.
|
|
||||||
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
|
|
||||||
:param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
|
|
||||||
endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
|
|
||||||
Some of the supported parameters:
|
|
||||||
- `max_tokens`: The maximum number of tokens the output text can have.
|
|
||||||
- `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
|
|
||||||
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
|
|
||||||
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
|
|
||||||
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
|
|
||||||
comprising the top 10% probability mass are considered.
|
|
||||||
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
|
|
||||||
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
|
|
||||||
- `stop`: One or more sequences after which the LLM should stop generating tokens.
|
|
||||||
- `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
|
|
||||||
the model will be less likely to repeat the same token in the text.
|
|
||||||
- `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
|
|
||||||
Bigger values mean the model will be less likely to repeat the same token in the text.
|
|
||||||
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
|
|
||||||
values are the bias to add to that token.
|
|
||||||
"""
|
|
||||||
super().__init__(
|
|
||||||
api_key=api_key,
|
|
||||||
model_name=model_name,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
streaming_callback=streaming_callback,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
@ -1,2 +1,2 @@
|
|||||||
preview:
|
preview:
|
||||||
- Introduce `GPT35Generator`, a class that can generate completions using OpenAI Chat models like GPT3.5 and GPT4.
|
- Introduce `GPTGenerator`, a class that can generate completions using OpenAI Chat models like GPT3.5 and GPT4.
|
||||||
|
@ -1,110 +0,0 @@
|
|||||||
from unittest.mock import patch, Mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from haystack.preview.components.generators.openai.gpt4 import GPT4Generator, API_BASE_URL
|
|
||||||
from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback
|
|
||||||
|
|
||||||
|
|
||||||
class TestGPT4Generator:
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_init_default(self):
|
|
||||||
component = GPT4Generator(api_key="test-api-key")
|
|
||||||
assert component.system_prompt is None
|
|
||||||
assert component.api_key == "test-api-key"
|
|
||||||
assert component.model_name == "gpt-4"
|
|
||||||
assert component.streaming_callback is None
|
|
||||||
assert component.api_base_url == API_BASE_URL
|
|
||||||
assert component.model_parameters == {}
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_init_with_parameters(self):
|
|
||||||
callback = lambda x: x
|
|
||||||
component = GPT4Generator(
|
|
||||||
api_key="test-api-key",
|
|
||||||
model_name="gpt-4-32k",
|
|
||||||
system_prompt="test-system-prompt",
|
|
||||||
max_tokens=10,
|
|
||||||
some_test_param="test-params",
|
|
||||||
streaming_callback=callback,
|
|
||||||
api_base_url="test-base-url",
|
|
||||||
)
|
|
||||||
assert component.system_prompt == "test-system-prompt"
|
|
||||||
assert component.api_key == "test-api-key"
|
|
||||||
assert component.model_name == "gpt-4-32k"
|
|
||||||
assert component.streaming_callback == callback
|
|
||||||
assert component.api_base_url == "test-base-url"
|
|
||||||
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_to_dict_default(self):
|
|
||||||
component = GPT4Generator(api_key="test-api-key")
|
|
||||||
data = component.to_dict()
|
|
||||||
assert data == {
|
|
||||||
"type": "GPT4Generator",
|
|
||||||
"init_parameters": {
|
|
||||||
"api_key": "test-api-key",
|
|
||||||
"model_name": "gpt-4",
|
|
||||||
"system_prompt": None,
|
|
||||||
"streaming_callback": None,
|
|
||||||
"api_base_url": API_BASE_URL,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_to_dict_with_parameters(self):
|
|
||||||
component = GPT4Generator(
|
|
||||||
api_key="test-api-key",
|
|
||||||
model_name="gpt-4-32k",
|
|
||||||
system_prompt="test-system-prompt",
|
|
||||||
max_tokens=10,
|
|
||||||
some_test_param="test-params",
|
|
||||||
streaming_callback=default_streaming_callback,
|
|
||||||
api_base_url="test-base-url",
|
|
||||||
)
|
|
||||||
data = component.to_dict()
|
|
||||||
assert data == {
|
|
||||||
"type": "GPT4Generator",
|
|
||||||
"init_parameters": {
|
|
||||||
"api_key": "test-api-key",
|
|
||||||
"model_name": "gpt-4-32k",
|
|
||||||
"system_prompt": "test-system-prompt",
|
|
||||||
"max_tokens": 10,
|
|
||||||
"some_test_param": "test-params",
|
|
||||||
"api_base_url": "test-base-url",
|
|
||||||
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_from_dict_default(self):
|
|
||||||
data = {"type": "GPT4Generator", "init_parameters": {"api_key": "test-api-key"}}
|
|
||||||
component = GPT4Generator.from_dict(data)
|
|
||||||
assert component.system_prompt is None
|
|
||||||
assert component.api_key == "test-api-key"
|
|
||||||
assert component.model_name == "gpt-4"
|
|
||||||
assert component.streaming_callback is None
|
|
||||||
assert component.api_base_url == API_BASE_URL
|
|
||||||
assert component.model_parameters == {}
|
|
||||||
|
|
||||||
@pytest.mark.unit
|
|
||||||
def test_from_dict(self):
|
|
||||||
data = {
|
|
||||||
"type": "GPT4Generator",
|
|
||||||
"init_parameters": {
|
|
||||||
"api_key": "test-api-key",
|
|
||||||
"model_name": "gpt-4-32k",
|
|
||||||
"system_prompt": "test-system-prompt",
|
|
||||||
"max_tokens": 10,
|
|
||||||
"some_test_param": "test-params",
|
|
||||||
"api_base_url": "test-base-url",
|
|
||||||
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
component = GPT4Generator.from_dict(data)
|
|
||||||
assert component.system_prompt == "test-system-prompt"
|
|
||||||
assert component.api_key == "test-api-key"
|
|
||||||
assert component.model_name == "gpt-4-32k"
|
|
||||||
assert component.streaming_callback == default_streaming_callback
|
|
||||||
assert component.api_base_url == "test-base-url"
|
|
||||||
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
|
|
@ -6,8 +6,8 @@ import pytest
|
|||||||
import openai
|
import openai
|
||||||
from openai.util import convert_to_openai_object
|
from openai.util import convert_to_openai_object
|
||||||
|
|
||||||
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator
|
from haystack.preview.components.generators.openai.gpt import GPTGenerator
|
||||||
from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback
|
from haystack.preview.components.generators.openai.gpt import default_streaming_callback
|
||||||
|
|
||||||
|
|
||||||
def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
|
def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion:
|
||||||
@ -42,10 +42,10 @@ def mock_openai_stream_response(messages: str, model: str = "gpt-3.5-turbo-0301"
|
|||||||
yield convert_to_openai_object(base_dict)
|
yield convert_to_openai_object(base_dict)
|
||||||
|
|
||||||
|
|
||||||
class TestGPT35Generator:
|
class TestGPTGenerator:
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_default(self):
|
def test_init_default(self):
|
||||||
component = GPT35Generator(api_key="test-api-key")
|
component = GPTGenerator(api_key="test-api-key")
|
||||||
assert component.system_prompt is None
|
assert component.system_prompt is None
|
||||||
assert component.api_key == "test-api-key"
|
assert component.api_key == "test-api-key"
|
||||||
assert component.model_name == "gpt-3.5-turbo"
|
assert component.model_name == "gpt-3.5-turbo"
|
||||||
@ -56,7 +56,7 @@ class TestGPT35Generator:
|
|||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_init_with_parameters(self):
|
def test_init_with_parameters(self):
|
||||||
callback = lambda x: x
|
callback = lambda x: x
|
||||||
component = GPT35Generator(
|
component = GPTGenerator(
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
model_name="gpt-4",
|
model_name="gpt-4",
|
||||||
system_prompt="test-system-prompt",
|
system_prompt="test-system-prompt",
|
||||||
@ -74,10 +74,10 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict_default(self):
|
def test_to_dict_default(self):
|
||||||
component = GPT35Generator(api_key="test-api-key")
|
component = GPTGenerator(api_key="test-api-key")
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "GPT35Generator",
|
"type": "GPTGenerator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"api_key": "test-api-key",
|
"api_key": "test-api-key",
|
||||||
"model_name": "gpt-3.5-turbo",
|
"model_name": "gpt-3.5-turbo",
|
||||||
@ -89,7 +89,7 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict_with_parameters(self):
|
def test_to_dict_with_parameters(self):
|
||||||
component = GPT35Generator(
|
component = GPTGenerator(
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
model_name="gpt-4",
|
model_name="gpt-4",
|
||||||
system_prompt="test-system-prompt",
|
system_prompt="test-system-prompt",
|
||||||
@ -100,7 +100,7 @@ class TestGPT35Generator:
|
|||||||
)
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "GPT35Generator",
|
"type": "GPTGenerator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"api_key": "test-api-key",
|
"api_key": "test-api-key",
|
||||||
"model_name": "gpt-4",
|
"model_name": "gpt-4",
|
||||||
@ -108,13 +108,13 @@ class TestGPT35Generator:
|
|||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
"some_test_param": "test-params",
|
"some_test_param": "test-params",
|
||||||
"api_base_url": "test-base-url",
|
"api_base_url": "test-base-url",
|
||||||
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
|
"streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict_with_lambda_streaming_callback(self):
|
def test_to_dict_with_lambda_streaming_callback(self):
|
||||||
component = GPT35Generator(
|
component = GPTGenerator(
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
model_name="gpt-4",
|
model_name="gpt-4",
|
||||||
system_prompt="test-system-prompt",
|
system_prompt="test-system-prompt",
|
||||||
@ -125,7 +125,7 @@ class TestGPT35Generator:
|
|||||||
)
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "GPT35Generator",
|
"type": "GPTGenerator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"api_key": "test-api-key",
|
"api_key": "test-api-key",
|
||||||
"model_name": "gpt-4",
|
"model_name": "gpt-4",
|
||||||
@ -133,14 +133,14 @@ class TestGPT35Generator:
|
|||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
"some_test_param": "test-params",
|
"some_test_param": "test-params",
|
||||||
"api_base_url": "test-base-url",
|
"api_base_url": "test-base-url",
|
||||||
"streaming_callback": "test_gpt35_generator.<lambda>",
|
"streaming_callback": "test_gpt_generator.<lambda>",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_from_dict(self):
|
def test_from_dict(self):
|
||||||
data = {
|
data = {
|
||||||
"type": "GPT35Generator",
|
"type": "GPTGenerator",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"api_key": "test-api-key",
|
"api_key": "test-api-key",
|
||||||
"model_name": "gpt-4",
|
"model_name": "gpt-4",
|
||||||
@ -148,10 +148,10 @@ class TestGPT35Generator:
|
|||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
"some_test_param": "test-params",
|
"some_test_param": "test-params",
|
||||||
"api_base_url": "test-base-url",
|
"api_base_url": "test-base-url",
|
||||||
"streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
|
"streaming_callback": "haystack.preview.components.generators.openai.gpt.default_streaming_callback",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
component = GPT35Generator.from_dict(data)
|
component = GPTGenerator.from_dict(data)
|
||||||
assert component.system_prompt == "test-system-prompt"
|
assert component.system_prompt == "test-system-prompt"
|
||||||
assert component.api_key == "test-api-key"
|
assert component.api_key == "test-api-key"
|
||||||
assert component.model_name == "gpt-4"
|
assert component.model_name == "gpt-4"
|
||||||
@ -161,9 +161,9 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_run_no_system_prompt(self):
|
def test_run_no_system_prompt(self):
|
||||||
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
|
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
|
||||||
gpt35_patch.create.side_effect = mock_openai_response
|
gpt_patch.create.side_effect = mock_openai_response
|
||||||
component = GPT35Generator(api_key="test-api-key")
|
component = GPTGenerator(api_key="test-api-key")
|
||||||
results = component.run(prompt="test-prompt-1")
|
results = component.run(prompt="test-prompt-1")
|
||||||
assert results == {
|
assert results == {
|
||||||
"replies": ["response for these messages --> user: test-prompt-1"],
|
"replies": ["response for these messages --> user: test-prompt-1"],
|
||||||
@ -176,7 +176,7 @@ class TestGPT35Generator:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
gpt35_patch.create.assert_called_once_with(
|
gpt_patch.create.assert_called_once_with(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
messages=[{"role": "user", "content": "test-prompt-1"}],
|
messages=[{"role": "user", "content": "test-prompt-1"}],
|
||||||
@ -185,9 +185,9 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_run_with_system_prompt(self):
|
def test_run_with_system_prompt(self):
|
||||||
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
|
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
|
||||||
gpt35_patch.create.side_effect = mock_openai_response
|
gpt_patch.create.side_effect = mock_openai_response
|
||||||
component = GPT35Generator(api_key="test-api-key", system_prompt="test-system-prompt")
|
component = GPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt")
|
||||||
results = component.run(prompt="test-prompt-1")
|
results = component.run(prompt="test-prompt-1")
|
||||||
assert results == {
|
assert results == {
|
||||||
"replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1"],
|
"replies": ["response for these messages --> system: test-system-prompt - user: test-prompt-1"],
|
||||||
@ -200,7 +200,7 @@ class TestGPT35Generator:
|
|||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
gpt35_patch.create.assert_called_once_with(
|
gpt_patch.create.assert_called_once_with(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
messages=[
|
messages=[
|
||||||
@ -212,11 +212,11 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_run_with_parameters(self):
|
def test_run_with_parameters(self):
|
||||||
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
|
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
|
||||||
gpt35_patch.create.side_effect = mock_openai_response
|
gpt_patch.create.side_effect = mock_openai_response
|
||||||
component = GPT35Generator(api_key="test-api-key", max_tokens=10)
|
component = GPTGenerator(api_key="test-api-key", max_tokens=10)
|
||||||
component.run(prompt="test-prompt-1")
|
component.run(prompt="test-prompt-1")
|
||||||
gpt35_patch.create.assert_called_once_with(
|
gpt_patch.create.assert_called_once_with(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
messages=[{"role": "user", "content": "test-prompt-1"}],
|
messages=[{"role": "user", "content": "test-prompt-1"}],
|
||||||
@ -226,11 +226,11 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_run_stream(self):
|
def test_run_stream(self):
|
||||||
with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch:
|
with patch("haystack.preview.components.generators.openai.gpt.openai.ChatCompletion") as gpt_patch:
|
||||||
mock_callback = Mock()
|
mock_callback = Mock()
|
||||||
mock_callback.side_effect = default_streaming_callback
|
mock_callback.side_effect = default_streaming_callback
|
||||||
gpt35_patch.create.side_effect = mock_openai_stream_response
|
gpt_patch.create.side_effect = mock_openai_stream_response
|
||||||
component = GPT35Generator(
|
component = GPTGenerator(
|
||||||
api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback
|
api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback
|
||||||
)
|
)
|
||||||
results = component.run(prompt="test-prompt-1")
|
results = component.run(prompt="test-prompt-1")
|
||||||
@ -240,7 +240,7 @@ class TestGPT35Generator:
|
|||||||
}
|
}
|
||||||
# Calls count: 10 tokens per prompt + 1 token for the role + 1 empty termination token
|
# Calls count: 10 tokens per prompt + 1 token for the role + 1 empty termination token
|
||||||
assert mock_callback.call_count == 12
|
assert mock_callback.call_count == 12
|
||||||
gpt35_patch.create.assert_called_once_with(
|
gpt_patch.create.assert_called_once_with(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
api_key="test-api-key",
|
api_key="test-api-key",
|
||||||
messages=[
|
messages=[
|
||||||
@ -252,7 +252,7 @@ class TestGPT35Generator:
|
|||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_check_truncated_answers(self, caplog):
|
def test_check_truncated_answers(self, caplog):
|
||||||
component = GPT35Generator(api_key="test-api-key")
|
component = GPTGenerator(api_key="test-api-key")
|
||||||
metadata = [
|
metadata = [
|
||||||
{"finish_reason": "stop"},
|
{"finish_reason": "stop"},
|
||||||
{"finish_reason": "content_filter"},
|
{"finish_reason": "content_filter"},
|
||||||
@ -270,8 +270,8 @@ class TestGPT35Generator:
|
|||||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||||
)
|
)
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_gpt35_generator_run(self):
|
def test_gpt_generator_run(self):
|
||||||
component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
|
component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
|
||||||
results = component.run(prompt="What's the capital of France?")
|
results = component.run(prompt="What's the capital of France?")
|
||||||
assert len(results["replies"]) == 1
|
assert len(results["replies"]) == 1
|
||||||
assert "Paris" in results["replies"][0]
|
assert "Paris" in results["replies"][0]
|
||||||
@ -284,10 +284,8 @@ class TestGPT35Generator:
|
|||||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||||
)
|
)
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_gpt35_generator_run_wrong_model_name(self):
|
def test_gpt_generator_run_wrong_model_name(self):
|
||||||
component = GPT35Generator(
|
component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
|
||||||
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1
|
|
||||||
)
|
|
||||||
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
|
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
|
||||||
component.run(prompt="What's the capital of France?")
|
component.run(prompt="What's the capital of France?")
|
||||||
|
|
||||||
@ -296,7 +294,7 @@ class TestGPT35Generator:
|
|||||||
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
|
||||||
)
|
)
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_gpt35_generator_run_streaming(self):
|
def test_gpt_generator_run_streaming(self):
|
||||||
class Callback:
|
class Callback:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.responses = ""
|
self.responses = ""
|
||||||
@ -306,7 +304,7 @@ class TestGPT35Generator:
|
|||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
callback = Callback()
|
callback = Callback()
|
||||||
component = GPT35Generator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
|
component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
|
||||||
results = component.run(prompt="What's the capital of France?")
|
results = component.run(prompt="What's the capital of France?")
|
||||||
|
|
||||||
assert len(results["replies"]) == 1
|
assert len(results["replies"]) == 1
|
Loading…
x
Reference in New Issue
Block a user