feat: GPT4Generator (#5744)

* add gpt4generator

* add e2e

* add tests

* reno

* fix e2e

* Update test/preview/components/generators/openai/test_gpt4_generator.py

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
This commit is contained in:
ZanSara 2023-09-13 09:07:09 +01:00 committed by GitHub
parent 75dc60b0bb
commit 2c4d839b64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 200 additions and 11 deletions

View File

@ -2,21 +2,21 @@ import os
import pytest
import openai
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator
from haystack.preview.components.generators.openai.gpt4 import GPT4Generator
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run():
component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
@pytest.mark.parametrize("generator_class,model_name", [(GPT35Generator, "gpt-3.5"), (GPT4Generator, "gpt-4")])
def test_gpt35_generator_run(generator_class, model_name):
component = generator_class(api_key=os.environ.get("OPENAI_API_KEY"), n=1)
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0]["model"]
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
@ -24,8 +24,9 @@ def test_gpt35_generator_run():
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run_wrong_model_name():
component = GPT35Generator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
@pytest.mark.parametrize("generator_class", [GPT35Generator, GPT4Generator])
def test_gpt35_generator_run_wrong_model_name(generator_class):
component = generator_class(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1)
with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"):
component.run(prompt="What's the capital of France?")
@ -34,7 +35,8 @@ def test_gpt35_generator_run_wrong_model_name():
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_gpt35_generator_run_streaming():
@pytest.mark.parametrize("generator_class,model_name", [(GPT35Generator, "gpt-3.5"), (GPT4Generator, "gpt-4")])
def test_gpt35_generator_run_streaming(generator_class, model_name):
class Callback:
def __init__(self):
self.responses = ""
@ -44,14 +46,14 @@ def test_gpt35_generator_run_streaming():
return chunk
callback = Callback()
component = GPT35Generator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
component = generator_class(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1)
results = component.run(prompt="What's the capital of France?")
assert len(results["replies"]) == 1
assert "Paris" in results["replies"][0]
assert len(results["metadata"]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0]["model"]
assert model_name in results["metadata"][0]["model"]
assert "stop" == results["metadata"][0]["finish_reason"]
assert callback.responses == results["replies"][0]

View File

@ -13,6 +13,9 @@ from haystack.preview import component, default_from_dict, default_to_dict, Dese
logger = logging.getLogger(__name__)
API_BASE_URL = "https://api.openai.com/v1"
@dataclass
class _ChatMessage:
content: str
@ -44,7 +47,7 @@ class GPT35Generator:
model_name: str = "gpt-3.5-turbo",
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable] = None,
api_base_url: str = "https://api.openai.com/v1",
api_base_url: str = API_BASE_URL,
**kwargs,
):
"""

View File

@ -0,0 +1,71 @@
from typing import Optional, Callable
import logging
from haystack.preview import component
from haystack.preview.components.generators.openai.gpt35 import GPT35Generator, API_BASE_URL
logger = logging.getLogger(__name__)
@component
class GPT4Generator(GPT35Generator):
    """
    Text generator backed by OpenAI's GPT-4 family of chat models.

    This is a thin specialization of `GPT35Generator`: it only changes the default `model_name` to `gpt-4` and
    forwards every argument to the parent class. Requests go through the OpenAI SDK (the `openai` package).

    See [OpenAI GPT4 API](https://platform.openai.com/docs/guides/chat) for more details.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "gpt-4",
        system_prompt: Optional[str] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = API_BASE_URL,
        **kwargs,
    ):
        """
        Builds a GPT4Generator that talks to one of OpenAI's GPT-4 models.

        :param api_key: The OpenAI API key.
        :param model_name: Name of the model to query. Defaults to `gpt-4`.
        :param system_prompt: Optional message placed at the very start of every conversation. A conversation
            usually opens with a system message and then alternates between 'user' messages (the "queries") and
            'assistant' messages (the "responses"). The system message steers the assistant's behavior, for
            instance by giving it a personality or instructions to follow for the whole conversation.
        :param streaming_callback: Callable invoked each time a new token arrives from the stream. It should
            accept the received token plus **kwargs, and return the token to be sent to the stream.
            NOTE(review): the original documentation states that, when no callback is provided, the token is
            printed to stdout; the default here is `None`, so confirm this against `GPT35Generator`.
        :param api_base_url: Base URL of the OpenAI API, defaults to `https://api.openai.com/v1`.
        :param kwargs: Additional model parameters, all sent directly to the OpenAI endpoint. See the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
            Some of the supported parameters:
            - `max_tokens`: Upper bound on the number of tokens in the output text.
            - `temperature`: Sampling temperature. Higher values make the model take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined
                answer.
            - `top_p`: Nucleus sampling, an alternative to temperature sampling: the model only considers the
                tokens within the top_p probability mass. So 0.1 means only the tokens comprising the top 10%
                probability mass are considered.
            - `n`: Number of completions generated per prompt. For example, with 3 prompts and n set to 2, two
                completions are produced for each of the three prompts, for 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: Penalty applied when a token is already present at all. Bigger values make
                the model less likely to repeat the same token in the text.
            - `frequency_penalty`: Penalty applied when a token has already been generated in the text.
                Bigger values make the model less likely to repeat the same token in the text.
            - `logit_bias`: Logit bias for specific tokens. Keys of the dictionary are tokens and values are
                the bias to add to that token.
        """
        # Nothing GPT-4 specific happens here beyond the different default for `model_name`:
        # all the state is set up by the parent constructor.
        super().__init__(
            api_key=api_key,
            model_name=model_name,
            system_prompt=system_prompt,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
            **kwargs,
        )

View File

@ -0,0 +1,3 @@
preview:
- Adds `GPT4Generator`, an LLM component based on `GPT35Generator`

View File

@ -0,0 +1,110 @@
from unittest.mock import patch, Mock
import pytest
from haystack.preview.components.generators.openai.gpt4 import GPT4Generator, API_BASE_URL
from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback
class TestGPT4Generator:
    """Unit tests for the construction and (de)serialization of `GPT4Generator`."""

    @pytest.mark.unit
    def test_init_default(self):
        """With only an API key, every other attribute takes its default value."""
        generator = GPT4Generator(api_key="test-api-key")

        assert generator.api_key == "test-api-key"
        assert generator.model_name == "gpt-4"
        assert generator.system_prompt is None
        assert generator.streaming_callback is None
        assert generator.api_base_url == API_BASE_URL
        assert generator.model_parameters == {}

    @pytest.mark.unit
    def test_init_with_parameters(self):
        """Explicit arguments are stored as given; unknown kwargs end up in `model_parameters`."""

        def callback(x):
            return x

        generator = GPT4Generator(
            api_key="test-api-key",
            model_name="gpt-4-32k",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=callback,
            api_base_url="test-base-url",
        )
        expected_model_parameters = {"max_tokens": 10, "some_test_param": "test-params"}

        assert generator.api_key == "test-api-key"
        assert generator.model_name == "gpt-4-32k"
        assert generator.system_prompt == "test-system-prompt"
        assert generator.streaming_callback == callback
        assert generator.api_base_url == "test-base-url"
        assert generator.model_parameters == expected_model_parameters

    @pytest.mark.unit
    def test_to_dict_default(self):
        """A default instance serializes to its type name plus the default init parameters."""
        expected = {
            "type": "GPT4Generator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model_name": "gpt-4",
                "system_prompt": None,
                "streaming_callback": None,
                "api_base_url": API_BASE_URL,
            },
        }

        serialized = GPT4Generator(api_key="test-api-key").to_dict()

        assert serialized == expected

    @pytest.mark.unit
    def test_to_dict_with_parameters(self):
        """Custom parameters are serialized, with the streaming callback written as a dotted import path."""
        expected = {
            "type": "GPT4Generator",
            "init_parameters": {
                "api_key": "test-api-key",
                "model_name": "gpt-4-32k",
                "system_prompt": "test-system-prompt",
                "max_tokens": 10,
                "some_test_param": "test-params",
                "api_base_url": "test-base-url",
                "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
            },
        }
        generator = GPT4Generator(
            api_key="test-api-key",
            model_name="gpt-4-32k",
            system_prompt="test-system-prompt",
            max_tokens=10,
            some_test_param="test-params",
            streaming_callback=default_streaming_callback,
            api_base_url="test-base-url",
        )

        serialized = generator.to_dict()

        assert serialized == expected

    @pytest.mark.unit
    def test_from_dict_default(self):
        """Deserializing a dict that carries only the API key yields an instance with default values."""
        serialized = {"type": "GPT4Generator", "init_parameters": {"api_key": "test-api-key"}}

        generator = GPT4Generator.from_dict(serialized)

        assert generator.api_key == "test-api-key"
        assert generator.model_name == "gpt-4"
        assert generator.system_prompt is None
        assert generator.streaming_callback is None
        assert generator.api_base_url == API_BASE_URL
        assert generator.model_parameters == {}

    @pytest.mark.unit
    def test_from_dict(self):
        """Deserialization restores all init parameters and resolves the callback from its dotted path."""
        init_parameters = {
            "api_key": "test-api-key",
            "model_name": "gpt-4-32k",
            "system_prompt": "test-system-prompt",
            "max_tokens": 10,
            "some_test_param": "test-params",
            "api_base_url": "test-base-url",
            "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback",
        }
        serialized = {"type": "GPT4Generator", "init_parameters": init_parameters}
        expected_model_parameters = {"max_tokens": 10, "some_test_param": "test-params"}

        generator = GPT4Generator.from_dict(serialized)

        assert generator.api_key == "test-api-key"
        assert generator.model_name == "gpt-4-32k"
        assert generator.system_prompt == "test-system-prompt"
        assert generator.streaming_callback == default_streaming_callback
        assert generator.api_base_url == "test-base-url"
        assert generator.model_parameters == expected_model_parameters