chore: Rename GPT generators, deprecate old names (#6626)

This commit is contained in:
Vladimir Blagojevic 2023-12-22 19:37:29 +01:00 committed by GitHub
parent c0f1dab454
commit 506ab81d26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 144 additions and 81 deletions

View File

@ -7,7 +7,7 @@ from haystack.document_stores import InMemoryDocumentStore
from haystack.components.writers import DocumentWriter
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
from haystack.components.generators import GPTGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
@ -30,7 +30,7 @@ def test_bm25_rag_pipeline(tmp_path):
rag_pipeline = Pipeline()
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "llm")
@ -102,7 +102,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
)
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
rag_pipeline.connect("text_embedder", "retriever")
rag_pipeline.connect("retriever", "prompt_builder.documents")

View File

@ -2,7 +2,7 @@ import json
import os
from haystack import Pipeline
from haystack.components.generators.openai import GPTGenerator
from haystack.components.generators.openai import OpenAIGenerator
from haystack.components.builders.prompt_builder import PromptBuilder
import random
from haystack import component
@ -83,8 +83,8 @@ prompt_template = """
# Let's build the pipeline (Make sure to set OPENAI_API_KEY as an environment variable)
pipeline = Pipeline(max_loops_allowed=5)
pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
pipeline.add_component(instance=GPTGenerator(), name="llm")
pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser")
pipeline.add_component(instance=OpenAIGenerator(), name="llm")
pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser") # type: ignore
pipeline.connect("prompt_builder", "llm")
pipeline.connect("llm", "output_parser")

View File

@ -2,7 +2,7 @@ import os
from haystack import Pipeline, Document
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.generators import GPTGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
@ -20,7 +20,7 @@ documents = [
]
document_store.write_documents(documents)
# Build a RAG pipeline with a Retriever to get relevant documents to the query and a GPTGenerator interacting with LLMs using a custom prompt.
# Build a RAG pipeline with a Retriever to get relevant documents to the query and an OpenAIGenerator interacting with LLMs using a custom prompt.
prompt_template = """
Given these documents, answer the question.\nDocuments:
{% for doc in documents %}
@ -33,7 +33,7 @@ Given these documents, answer the question.\nDocuments:
rag_pipeline = Pipeline()
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=OPENAI_API_KEY), name="llm")
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=OPENAI_API_KEY), name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "llm")

View File

@ -7,7 +7,7 @@ from canals.component.types import Variadic
from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores import InMemoryDocumentStore
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.components.generators import GPTGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.others import Multiplexer
from haystack.components.routers.conditional_router import ConditionalRouter
@ -99,7 +99,7 @@ def self_correcting_pipeline():
),
name="prompt_builder",
)
rag_pipeline.add_component(instance=GPTGenerator(), name="llm")
rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm")
rag_pipeline.add_component(
instance=ConditionalRouter(
routes=[

View File

@ -1,10 +1,11 @@
import os
from pathlib import Path
from haystack import Document
from haystack import Pipeline
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.generators import GPTGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.retrievers import InMemoryBM25Retriever
from haystack.document_stores import InMemoryDocumentStore
@ -22,7 +23,7 @@ prompt_template = """
rag_pipeline = Pipeline()
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "llm")
@ -31,7 +32,7 @@ rag_pipeline.connect("llm.metadata", "answer_builder.metadata")
rag_pipeline.connect("retriever", "answer_builder.documents")
# Draw the pipeline
rag_pipeline.draw("./rag_pipeline.png")
rag_pipeline.draw(Path("./rag_pipeline.png"))
# Add Documents
documents = [
@ -43,7 +44,7 @@ documents = [
content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves."
),
]
rag_pipeline.get_component("retriever").document_store.write_documents(documents)
rag_pipeline.get_component("retriever").document_store.write_documents(documents) # type: ignore
# Run the pipeline
question = "How many languages are there?"

View File

@ -30,13 +30,13 @@ class DynamicPromptBuilder:
```python
from haystack.components.builders import DynamicPromptBuilder
from haystack.components.generators.chat import GPTChatGenerator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack import Pipeline
# no parameter init, we don't use any runtime template variables
prompt_builder = DynamicPromptBuilder()
llm = GPTChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
llm = OpenAIChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
pipe = Pipeline()
pipe.add_component("prompt_builder", prompt_builder)
@ -88,14 +88,14 @@ class DynamicPromptBuilder:
```python
from haystack.components.builders import DynamicPromptBuilder
from haystack.components.generators.chat import GPTChatGenerator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, Document
from haystack import Pipeline, component
from typing import List
# we'll use documents runtime variable in our template, so we need to specify it in the init
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
llm = GPTChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
llm = OpenAIChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
@component
@ -135,7 +135,7 @@ class DynamicPromptBuilder:
```python
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=False)
llm = GPTGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
llm = OpenAIGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
@component

View File

@ -1,5 +1,5 @@
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
from haystack.components.generators.openai import GPTGenerator
from haystack.components.generators.openai import OpenAIGenerator, GPTGenerator
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "GPTGenerator"]

View File

@ -1,4 +1,4 @@
from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
from haystack.components.generators.chat.openai import GPTChatGenerator
from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
__all__ = ["HuggingFaceTGIChatGenerator", "GPTChatGenerator"]
__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator"]

View File

@ -1,6 +1,7 @@
import dataclasses
import json
import logging
import warnings
from typing import Optional, List, Callable, Dict, Any, Union
from openai import OpenAI, Stream
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
@component
class GPTChatGenerator:
class OpenAIChatGenerator:
"""
Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo
family of models accessed through the chat completions API endpoint.
@ -27,12 +28,12 @@ class GPTChatGenerator:
[documentation](https://platform.openai.com/docs/api-reference/chat).
```python
from haystack.components.generators.chat import GPTChatGenerator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
client = GPTChatGenerator()
client = OpenAIChatGenerator()
response = client.run(messages)
print(response)
@ -66,7 +67,7 @@ class GPTChatGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
@ -126,7 +127,7 @@ class GPTChatGenerator:
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPTChatGenerator":
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
@ -277,3 +278,29 @@ class GPTChatGenerator:
logger.warning(
"The completion for index %s has been truncated due to the content filter.", message.meta["index"]
)
class GPTChatGenerator(OpenAIChatGenerator):
    """
    Deprecated name for `OpenAIChatGenerator`.

    Creating an instance emits a `UserWarning` pointing to `OpenAIChatGenerator` and then
    initializes the parent class with the arguments it was given.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Warn that this class is deprecated, then delegate to `OpenAIChatGenerator.__init__`.

        Every parameter is forwarded unchanged, by keyword, to the parent initializer.
        """
        deprecation_notice = (
            "GPTChatGenerator is deprecated and will be removed in the next beta release. "
            "Please use OpenAIChatGenerator instead."
        )
        # stacklevel=2 reports the warning against the caller of __init__ instead of this line.
        warnings.warn(deprecation_notice, category=UserWarning, stacklevel=2)

        forwarded_arguments = {
            "api_key": api_key,
            "model_name": model_name,
            "streaming_callback": streaming_callback,
            "api_base_url": api_base_url,
            "organization": organization,
            "generation_kwargs": generation_kwargs,
        }
        super().__init__(**forwarded_arguments)

View File

@ -1,6 +1,7 @@
import dataclasses
import json
import logging
import warnings
from typing import Optional, List, Callable, Dict, Any, Union
from openai import OpenAI, Stream
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
@component
class GPTGenerator:
class OpenAIGenerator:
"""
Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo
family of models.
@ -27,8 +28,8 @@ class GPTGenerator:
[documentation](https://platform.openai.com/docs/api-reference/chat).
```python
from haystack.components.generators import GPTGenerator
client = GPTGenerator()
from haystack.components.generators import OpenAIGenerator
client = OpenAIGenerator()
response = client.run("What's Natural Language Processing? Be brief.")
print(response)
@ -59,7 +60,7 @@ class GPTGenerator:
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
@ -123,7 +124,7 @@ class GPTGenerator:
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
@ -279,3 +280,31 @@ class GPTGenerator:
logger.warning(
"The completion for index %s has been truncated due to the content filter.", message.meta["index"]
)
class GPTGenerator(OpenAIGenerator):
    """
    Deprecated name for `OpenAIGenerator`.

    Creating an instance emits a `UserWarning` pointing to `OpenAIGenerator` and then
    initializes the parent class with the arguments it was given.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Warn that this class is deprecated, then delegate to `OpenAIGenerator.__init__`.

        Every parameter is forwarded unchanged, by keyword, to the parent initializer.
        """
        deprecation_notice = (
            "GPTGenerator is deprecated and will be removed in the next beta release. "
            "Please use OpenAIGenerator instead."
        )
        # stacklevel=2 reports the warning against the caller of __init__ instead of this line.
        warnings.warn(deprecation_notice, category=UserWarning, stacklevel=2)

        forwarded_arguments = {
            "api_key": api_key,
            "model_name": model_name,
            "streaming_callback": streaming_callback,
            "api_base_url": api_base_url,
            "organization": organization,
            "system_prompt": system_prompt,
            "generation_kwargs": generation_kwargs,
        }
        super().__init__(**forwarded_arguments)

View File

@ -8,7 +8,7 @@ from haystack import Pipeline
from haystack.components.builders.answer_builder import AnswerBuilder
from haystack.components.builders.prompt_builder import PromptBuilder
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import GPTGenerator, HuggingFaceTGIGenerator
from haystack.components.generators import OpenAIGenerator, HuggingFaceTGIGenerator
from haystack.components.retrievers import InMemoryEmbeddingRetriever
from haystack.dataclasses import Answer
from haystack.document_stores import InMemoryDocumentStore, DocumentStore
@ -179,13 +179,13 @@ class _GeneratorResolver(ABC):
class _OpenAIResolved(_GeneratorResolver):
"""
Resolves the OpenAI GPTGenerator.
Resolves the OpenAIGenerator.
"""
def resolve(self, model_key: str, api_key: str) -> Any:
# does the model_key match the OpenAI GPT model name pattern?
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
return GPTGenerator(model_name=model_key, api_key=api_key)
return OpenAIGenerator(model_name=model_key, api_key=api_key)
return None

View File

@ -0,0 +1,4 @@
---
deprecations:
- |
Deprecate GPTGenerator and GPTChatGenerator. Replace them with OpenAIGenerator and OpenAIChatGenerator.

View File

@ -3,7 +3,7 @@ import os
import pytest
from openai import OpenAIError
from haystack.components.generators.chat import GPTChatGenerator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import default_streaming_callback
from haystack.dataclasses import ChatMessage, StreamingChunk
@ -16,9 +16,9 @@ def chat_messages():
]
class TestGPTChatGenerator:
class TestOpenAIChatGenerator:
def test_init_default(self):
component = GPTChatGenerator(api_key="test-api-key")
component = OpenAIChatGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
@ -27,10 +27,10 @@ class TestGPTChatGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
GPTChatGenerator()
OpenAIChatGenerator()
def test_init_with_parameters(self):
component = GPTChatGenerator(
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=default_streaming_callback,
@ -43,10 +43,10 @@ class TestGPTChatGenerator:
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self):
component = GPTChatGenerator(api_key="test-api-key")
component = OpenAIChatGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"organization": None,
@ -57,7 +57,7 @@ class TestGPTChatGenerator:
}
def test_to_dict_with_parameters(self):
component = GPTChatGenerator(
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=default_streaming_callback,
@ -66,7 +66,7 @@ class TestGPTChatGenerator:
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"organization": None,
@ -77,7 +77,7 @@ class TestGPTChatGenerator:
}
def test_to_dict_with_lambda_streaming_callback(self):
component = GPTChatGenerator(
component = OpenAIChatGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=lambda x: x,
@ -86,7 +86,7 @@ class TestGPTChatGenerator:
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"organization": None,
@ -98,7 +98,7 @@ class TestGPTChatGenerator:
def test_from_dict(self):
data = {
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"api_base_url": "test-base-url",
@ -106,7 +106,7 @@ class TestGPTChatGenerator:
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = GPTChatGenerator.from_dict(data)
component = OpenAIChatGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
@ -115,7 +115,7 @@ class TestGPTChatGenerator:
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"model_name": "gpt-4",
"organization": None,
@ -125,10 +125,10 @@ class TestGPTChatGenerator:
},
}
with pytest.raises(OpenAIError):
GPTChatGenerator.from_dict(data)
OpenAIChatGenerator.from_dict(data)
def test_run(self, chat_messages, mock_chat_completion):
component = GPTChatGenerator()
component = OpenAIChatGenerator()
response = component.run(chat_messages)
# check that the component returns the correct ChatMessage response
@ -139,7 +139,7 @@ class TestGPTChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_run_with_params(self, chat_messages, mock_chat_completion):
component = GPTChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
response = component.run(chat_messages)
# check that the component calls the OpenAI API with the correct parameters
@ -161,7 +161,7 @@ class TestGPTChatGenerator:
nonlocal streaming_callback_called
streaming_callback_called = True
component = GPTChatGenerator(streaming_callback=streaming_callback)
component = OpenAIChatGenerator(streaming_callback=streaming_callback)
response = component.run(chat_messages)
# check we called the streaming callback
@ -176,7 +176,7 @@ class TestGPTChatGenerator:
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
def test_check_abnormal_completions(self, caplog):
component = GPTChatGenerator(api_key="test-api-key")
component = OpenAIChatGenerator(api_key="test-api-key")
messages = [
ChatMessage.from_assistant(
"", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
@ -208,7 +208,7 @@ class TestGPTChatGenerator:
@pytest.mark.integration
def test_live_run(self):
chat_messages = [ChatMessage.from_user("What's the capital of France")]
component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
@ -222,7 +222,9 @@ class TestGPTChatGenerator:
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = GPTChatGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIChatGenerator(
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
)
with pytest.raises(OpenAIError):
component.run(chat_messages)
@ -242,7 +244,7 @@ class TestGPTChatGenerator:
self.responses += chunk.content if chunk.content else ""
callback = Callback()
component = GPTChatGenerator(streaming_callback=callback)
component = OpenAIChatGenerator(streaming_callback=callback)
results = component.run([ChatMessage.from_user("What's the capital of France?")])
assert len(results["replies"]) == 1

View File

@ -4,14 +4,14 @@ from typing import List
import pytest
from openai import OpenAIError
from haystack.components.generators import GPTGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.utils import default_streaming_callback
from haystack.dataclasses import StreamingChunk, ChatMessage
class TestGPTGenerator:
class TestOpenAIGenerator:
def test_init_default(self):
component = GPTGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key="test-api-key")
assert component.client.api_key == "test-api-key"
assert component.model_name == "gpt-3.5-turbo"
assert component.streaming_callback is None
@ -20,10 +20,10 @@ class TestGPTGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
GPTGenerator()
OpenAIGenerator()
def test_init_with_parameters(self):
component = GPTGenerator(
component = OpenAIGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=default_streaming_callback,
@ -36,10 +36,10 @@ class TestGPTGenerator:
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self):
component = GPTGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key="test-api-key")
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.openai.GPTGenerator",
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"model_name": "gpt-3.5-turbo",
"streaming_callback": None,
@ -50,7 +50,7 @@ class TestGPTGenerator:
}
def test_to_dict_with_parameters(self):
component = GPTGenerator(
component = OpenAIGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=default_streaming_callback,
@ -59,7 +59,7 @@ class TestGPTGenerator:
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.openai.GPTGenerator",
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": None,
@ -70,7 +70,7 @@ class TestGPTGenerator:
}
def test_to_dict_with_lambda_streaming_callback(self):
component = GPTGenerator(
component = OpenAIGenerator(
api_key="test-api-key",
model_name="gpt-4",
streaming_callback=lambda x: x,
@ -79,7 +79,7 @@ class TestGPTGenerator:
)
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.openai.GPTGenerator",
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": None,
@ -92,7 +92,7 @@ class TestGPTGenerator:
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "haystack.components.generators.openai.GPTGenerator",
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"model_name": "gpt-4",
"system_prompt": None,
@ -101,7 +101,7 @@ class TestGPTGenerator:
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = GPTGenerator.from_dict(data)
component = OpenAIGenerator.from_dict(data)
assert component.model_name == "gpt-4"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
@ -110,7 +110,7 @@ class TestGPTGenerator:
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "haystack.components.generators.openai.GPTGenerator",
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"model_name": "gpt-4",
"api_base_url": "test-base-url",
@ -119,10 +119,10 @@ class TestGPTGenerator:
},
}
with pytest.raises(OpenAIError):
GPTGenerator.from_dict(data)
OpenAIGenerator.from_dict(data)
def test_run(self, mock_chat_completion):
component = GPTGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key="test-api-key")
response = component.run("What's Natural Language Processing?")
# check that the component returns the correct ChatMessage response
@ -139,7 +139,7 @@ class TestGPTGenerator:
nonlocal streaming_callback_called
streaming_callback_called = True
component = GPTGenerator(streaming_callback=streaming_callback)
component = OpenAIGenerator(streaming_callback=streaming_callback)
response = component.run("Come on, stream!")
# check we called the streaming callback
@ -153,7 +153,7 @@ class TestGPTGenerator:
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk
def test_run_with_params(self, mock_chat_completion):
component = GPTGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
response = component.run("What's Natural Language Processing?")
# check that the component calls the OpenAI API with the correct parameters
@ -169,7 +169,7 @@ class TestGPTGenerator:
assert [isinstance(reply, str) for reply in response["replies"]]
def test_check_abnormal_completions(self, caplog):
component = GPTGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key="test-api-key")
# underlying implementation uses ChatMessage objects so we have to use them here
messages: List[ChatMessage] = []
@ -202,7 +202,7 @@ class TestGPTGenerator:
)
@pytest.mark.integration
def test_live_run(self):
component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1
assert len(results["meta"]) == 1
@ -224,7 +224,7 @@ class TestGPTGenerator:
)
@pytest.mark.integration
def test_live_run_wrong_model(self):
component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
with pytest.raises(OpenAIError):
component.run("Whatever")
@ -244,7 +244,7 @@ class TestGPTGenerator:
self.responses += chunk.content if chunk.content else ""
callback = Callback()
component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1