feat: GPT4Generator (#5744)

* add gpt4generator
* add e2e
* add tests
* reno
* fix e2e
* Update test/preview/components/generators/openai/test_gpt4_generator.py

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
parent: 75dc60b0bb
commit: 2c4d839b64
@@ -2,21 +2,21 @@ import os
 import pytest
 import openai
 from haystack.preview.components.generators.openai.gpt35 import GPT35Generator
+from haystack.preview.components.generators.openai.gpt4 import GPT4Generator


 @pytest.mark.skipif(
     not os.environ.get("OPENAI_API_KEY", None),
     reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
 )
-def test_gpt35_generator_run():
-    component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
+@pytest.mark.parametrize("generator_class,model_name", [(GPT35Generator, "gpt-3.5"), (GPT4Generator, "gpt-4")])
+def test_gpt35_generator_run(generator_class, model_name):
+    component = generator_class(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
     results = component.run(prompt="What's the capital of France?")

     assert len(results["replies"]) == 1
     assert "Paris" in results["replies"][0]

     assert len(results["metadata"]) == 1
-    assert "gpt-3.5-turbo" in results["metadata"][0]["model"]
+    assert model_name in results["metadata"][0]["model"]
     assert "stop" == results["metadata"][0]["finish_reason"]

@@ -24,8 +24,9 @@ def test_gpt35_generator_run():
     not os.environ.get("OPENAI_API_KEY", None),
     reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
 )
-def test_gpt35_generator_run_wrong_model_name():
-    component = GPT35Generator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
+@pytest.mark.parametrize("generator_class", [GPT35Generator, GPT4Generator])
+def test_gpt35_generator_run_wrong_model_name(generator_class):
+    component = generator_class(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
     with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
         component.run(prompt="What's the capital of France?")

@@ -34,7 +35,8 @@ def test_gpt35_generator_run_wrong_model_name():
     not os.environ.get("OPENAI_API_KEY", None),
     reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
 )
-def test_gpt35_generator_run_streaming():
+@pytest.mark.parametrize("generator_class,model_name", [(GPT35Generator, "gpt-3.5"), (GPT4Generator, "gpt-4")])
+def test_gpt35_generator_run_streaming(generator_class, model_name):
     class Callback:
         def __init__(self):
             self.responses = ""

@@ -44,14 +46,14 @@ def test_gpt35_generator_run_streaming():
             return chunk

     callback = Callback()
-    component = GPT35Generator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
+    component = generator_class(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
     results = component.run(prompt="What's the capital of France?")

     assert len(results["replies"]) == 1
     assert "Paris" in results["replies"][0]

     assert len(results["metadata"]) == 1
-    assert "gpt-3.5-turbo" in results["metadata"][0]["model"]
+    assert model_name in results["metadata"][0]["model"]
     assert "stop" == results["metadata"][0]["finish_reason"]

     assert callback.responses == results["replies"][0]
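For context, a minimal sketch (not part of this diff) of the call pattern the parametrized e2e tests above exercise, assuming OPENAI_API_KEY is exported and the preview package is installed:

import os

from haystack.preview.components.generators.openai.gpt4 import GPT4Generator

component = GPT4Generator(api_key=os.environ["OPENAI_API_KEY"], n=1)
results = component.run(prompt="What's the capital of France?")

# run() returns parallel "replies" and "metadata" lists, one entry per completion (n=1 here).
print(results["replies"][0])                    # e.g. "The capital of France is Paris."
print(results["metadata"][0]["model"])          # model string containing "gpt-4"
print(results["metadata"][0]["finish_reason"])  # "stop" on normal completion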
haystack/preview/components/generators/openai/gpt35.py
@@ -13,6 +13,9 @@ from haystack.preview import component, default_from_dict, default_to_dict, Dese
 logger = logging.getLogger(__name__)


+API_BASE_URL = "https://api.openai.com/v1"
+
+
 @dataclass
 class _ChatMessage:
     content: str

@@ -44,7 +47,7 @@ class GPT35Generator:
         model_name: str = "gpt-3.5-turbo",
         system_prompt: Optional[str] = None,
         streaming_callback: Optional[Callable] = None,
-        api_base_url: str = "https://api.openai.com/v1",
+        api_base_url: str = API_BASE_URL,
         **kwargs,
     ):
         """
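This change lifts the hardcoded base URL into a module-level constant, API_BASE_URL, so subclasses such as the new GPT4Generator can share it. A minimal sketch of importing and overriding it, assuming an OpenAI-compatible endpoint (the proxy URL below is hypothetical):

from haystack.preview.components.generators.openai.gpt35 import GPT35Generator, API_BASE_URL

assert API_BASE_URL == "https://api.openai.com/v1"

# Point the generator at a compatible endpoint instead of the default:
component = GPT35Generator(api_key="sk-...", api_base_url="https://my-proxy.example.com/v1")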
haystack/preview/components/generators/openai/gpt4.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import Optional, Callable

import logging

from haystack.preview import component
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator, API_BASE_URL


logger = logging.getLogger(__name__)


@component
class GPT4Generator(GPT35Generator):
    """
    LLM Generator compatible with GPT-4 large language models.

    Queries the LLM using OpenAI's API. Invocations are made using the OpenAI SDK ('openai' package).
    See [OpenAI GPT-4 API](https://platform.openai.com/docs/guides/chat) for more details.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gpt-4",
        system_prompt: Optional[str] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = API_BASE_URL,
        **kwargs,
    ):
        """
        Creates an instance of GPT4Generator for OpenAI's GPT-4 model.

        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation.
            Typically, a conversation is formatted with a system message first, followed by alternating messages from
            the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior
            of the assistant. For example, you can modify the personality of the assistant or provide specific
            instructions about how it should behave throughout the conversation.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
            endpoint. See the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty to apply if a token is already present at all. Higher values make
                the model less likely to repeat the same token in the text.
            - `frequency_penalty`: The penalty to apply if a token has already been generated in the text. Higher
                values make the model less likely to repeat the same token in the text.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        """
        super().__init__(
            api_key=api_key,
            model_name=model_name,
            system_prompt=system_prompt,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
            **kwargs,
        )
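A minimal usage sketch of the new component, combining the system_prompt and streaming_callback parameters documented above (the one-word system prompt and token printer are illustrative; the callback signature follows the docstring):

from haystack.preview.components.generators.openai.gpt4 import GPT4Generator

def print_token(token, **kwargs):
    # Called once per streamed token; must return the token.
    print(token, end="", flush=True)
    return token

component = GPT4Generator(
    api_key="sk-...",  # a real OpenAI API key
    system_prompt="You answer in exactly one word.",
    streaming_callback=print_token,
    max_tokens=16,  # forwarded to the OpenAI endpoint via **kwargs
)
results = component.run(prompt="What's the capital of France?")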
(release note, new file, 3 lines)
@@ -0,0 +1,3 @@

preview:
  - Adds `GPT4Generator`, an LLM component based on `GPT35Generator`
test/preview/components/generators/openai/test_gpt4_generator.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from unittest.mock import patch, Mock

import pytest

from haystack.preview.components.generators.openai.gpt4 import GPT4Generator, API_BASE_URL
from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback


class TestGPT4Generator:
    @pytest.mark.unit
    def test_init_default(self):
        component = GPT4Generator(api_key="test-api-key")
        assert component.system_prompt is None
        assert component.api_key == "test-api-key"
        assert component.model_name == "gpt-4"
        assert component.streaming_callback is None
        assert component.api_base_url == API_BASE_URL
        assert component.model_parameters == {}

    @pytest.mark.unit
    def test_init_with_parameters(self):
        callback = lambda x: x
        component = GPT4Generator(
            api_key="test-api-key",
            model_name="gpt-4-32k",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        assert component.system_prompt == "test-system-prompt"
        assert component.api_key == "test-api-key"
        assert component.model_name == "gpt-4-32k"
        assert component.streaming_callback == callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}

    @pytest.mark.unit
    def test_to_dict_default(self):
        component = GPT4Generator(api_key="test-api-key")
        data = component.to_dict()
        assert data == {
            "type": "GPT4Generator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model_name": "gpt-4",
                "system_prompt": None,
                "streaming_callback": None,
                "api_base_url": API_BASE_URL,
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        component = GPT4Generator(
            api_key="test-api-key",
            model_name="gpt-4-32k",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )
        data = component.to_dict()
        assert data == {
            "type": "GPT4Generator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model_name": "gpt-4-32k",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
            },
        }

    @pytest.mark.unit
    def test_from_dict_default(self):
        data = {"type": "GPT4Generator", "init_parameters": {"api_key": "test-api-key"}}
        component = GPT4Generator.from_dict(data)
        assert component.system_prompt is None
        assert component.api_key == "test-api-key"
        assert component.model_name == "gpt-4"
        assert component.streaming_callback is None
        assert component.api_base_url == API_BASE_URL
        assert component.model_parameters == {}

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "GPT4Generator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model_name": "gpt-4-32k",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
            },
        }
        component = GPT4Generator.from_dict(data)
        assert component.system_prompt == "test-system-prompt"
        assert component.api_key == "test-api-key"
        assert component.model_name == "gpt-4-32k"
        assert component.streaming_callback == default_streaming_callback
        assert component.api_base_url == "test-base-url"
        assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
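The to_dict/from_dict tests above pin down the serialization contract: plain init parameters are stored verbatim, while a streaming callback is stored by its import path. A sketch of the round-trip they verify:

from haystack.preview.components.generators.openai.gpt4 import GPT4Generator

component = GPT4Generator(api_key="test-api-key", model_name="gpt-4-32k", max_tokens=10)
data = component.to_dict()
# data["init_parameters"] now holds {"api_key": "test-api-key", "model_name": "gpt-4-32k", "max_tokens": 10, ...}

restored = GPT4Generator.from_dict(data)
assert restored.model_name == "gpt-4-32k"
assert restored.model_parameters == {"max_tokens": 10}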