mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-09 13:46:54 +00:00
chore: Rename GPT generators, deprecate old names (#6626)
This commit is contained in:
parent
c0f1dab454
commit
506ab81d26
@ -7,7 +7,7 @@ from haystack.document_stores import InMemoryDocumentStore
|
||||
from haystack.components.writers import DocumentWriter
|
||||
from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
|
||||
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
|
||||
from haystack.components.generators import GPTGenerator
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
|
||||
@ -30,7 +30,7 @@ def test_bm25_rag_pipeline(tmp_path):
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
rag_pipeline.connect("retriever", "prompt_builder.documents")
|
||||
rag_pipeline.connect("prompt_builder", "llm")
|
||||
@ -102,7 +102,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
|
||||
instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever"
|
||||
)
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
rag_pipeline.connect("text_embedder", "retriever")
|
||||
rag_pipeline.connect("retriever", "prompt_builder.documents")
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.components.generators.openai import GPTGenerator
|
||||
from haystack.components.generators.openai import OpenAIGenerator
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
import random
|
||||
from haystack import component
|
||||
@ -83,8 +83,8 @@ prompt_template = """
|
||||
# Let's build the pipeline (Make sure to set OPENAI_API_KEY as an environment variable)
|
||||
pipeline = Pipeline(max_loops_allowed=5)
|
||||
pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
pipeline.add_component(instance=GPTGenerator(), name="llm")
|
||||
pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser")
|
||||
pipeline.add_component(instance=OpenAIGenerator(), name="llm")
|
||||
pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser") # type: ignore
|
||||
|
||||
pipeline.connect("prompt_builder", "llm")
|
||||
pipeline.connect("llm", "output_parser")
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
from haystack import Pipeline, Document
|
||||
from haystack.document_stores import InMemoryDocumentStore
|
||||
from haystack.components.retrievers import InMemoryBM25Retriever
|
||||
from haystack.components.generators import GPTGenerator
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
|
||||
@ -20,7 +20,7 @@ documents = [
|
||||
]
|
||||
document_store.write_documents(documents)
|
||||
|
||||
# Build a RAG pipeline with a Retriever to get relevant documents to the query and a GPTGenerator interacting with LLMs using a custom prompt.
|
||||
# Build a RAG pipeline with a Retriever to get relevant documents to the query and a OpenAIGenerator interacting with LLMs using a custom prompt.
|
||||
prompt_template = """
|
||||
Given these documents, answer the question.\nDocuments:
|
||||
{% for doc in documents %}
|
||||
@ -33,7 +33,7 @@ Given these documents, answer the question.\nDocuments:
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever")
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=OPENAI_API_KEY), name="llm")
|
||||
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=OPENAI_API_KEY), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
rag_pipeline.connect("retriever", "prompt_builder.documents")
|
||||
rag_pipeline.connect("prompt_builder", "llm")
|
||||
|
||||
@ -7,7 +7,7 @@ from canals.component.types import Variadic
|
||||
from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError
|
||||
from haystack.document_stores import InMemoryDocumentStore
|
||||
from haystack.components.retrievers import InMemoryBM25Retriever
|
||||
from haystack.components.generators import GPTGenerator
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.components.others import Multiplexer
|
||||
from haystack.components.routers.conditional_router import ConditionalRouter
|
||||
@ -99,7 +99,7 @@ def self_correcting_pipeline():
|
||||
),
|
||||
name="prompt_builder",
|
||||
)
|
||||
rag_pipeline.add_component(instance=GPTGenerator(), name="llm")
|
||||
rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm")
|
||||
rag_pipeline.add_component(
|
||||
instance=ConditionalRouter(
|
||||
routes=[
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from haystack import Document
|
||||
from haystack import Pipeline
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.components.generators import GPTGenerator
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.retrievers import InMemoryBM25Retriever
|
||||
from haystack.document_stores import InMemoryDocumentStore
|
||||
|
||||
@ -22,7 +23,7 @@ prompt_template = """
|
||||
rag_pipeline = Pipeline()
|
||||
rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever")
|
||||
rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder")
|
||||
rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm")
|
||||
rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder")
|
||||
rag_pipeline.connect("retriever", "prompt_builder.documents")
|
||||
rag_pipeline.connect("prompt_builder", "llm")
|
||||
@ -31,7 +32,7 @@ rag_pipeline.connect("llm.metadata", "answer_builder.metadata")
|
||||
rag_pipeline.connect("retriever", "answer_builder.documents")
|
||||
|
||||
# Draw the pipeline
|
||||
rag_pipeline.draw("./rag_pipeline.png")
|
||||
rag_pipeline.draw(Path("./rag_pipeline.png"))
|
||||
|
||||
# Add Documents
|
||||
documents = [
|
||||
@ -43,7 +44,7 @@ documents = [
|
||||
content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves."
|
||||
),
|
||||
]
|
||||
rag_pipeline.get_component("retriever").document_store.write_documents(documents)
|
||||
rag_pipeline.get_component("retriever").document_store.write_documents(documents) # type: ignore
|
||||
|
||||
# Run the pipeline
|
||||
question = "How many languages are there?"
|
||||
|
||||
@ -30,13 +30,13 @@ class DynamicPromptBuilder:
|
||||
|
||||
```python
|
||||
from haystack.components.builders import DynamicPromptBuilder
|
||||
from haystack.components.generators.chat import GPTChatGenerator
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack import Pipeline
|
||||
|
||||
# no parameter init, we don't use any runtime template variables
|
||||
prompt_builder = DynamicPromptBuilder()
|
||||
llm = GPTChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
llm = OpenAIChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("prompt_builder", prompt_builder)
|
||||
@ -88,14 +88,14 @@ class DynamicPromptBuilder:
|
||||
|
||||
```python
|
||||
from haystack.components.builders import DynamicPromptBuilder
|
||||
from haystack.components.generators.chat import GPTChatGenerator
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage, Document
|
||||
from haystack import Pipeline, component
|
||||
from typing import List
|
||||
|
||||
# we'll use documents runtime variable in our template, so we need to specify it in the init
|
||||
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
|
||||
llm = GPTChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
llm = OpenAIChatGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
|
||||
|
||||
@component
|
||||
@ -135,7 +135,7 @@ class DynamicPromptBuilder:
|
||||
|
||||
```python
|
||||
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=False)
|
||||
llm = GPTGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
llm = OpenAIGenerator(api_key="<your-api-key>", model_name="gpt-3.5-turbo")
|
||||
|
||||
|
||||
@component
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
|
||||
from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
|
||||
from haystack.components.generators.openai import GPTGenerator
|
||||
from haystack.components.generators.openai import OpenAIGenerator, GPTGenerator
|
||||
|
||||
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"]
|
||||
__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "GPTGenerator"]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
|
||||
from haystack.components.generators.chat.openai import GPTChatGenerator
|
||||
from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
|
||||
|
||||
__all__ = ["HuggingFaceTGIChatGenerator", "GPTChatGenerator"]
|
||||
__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator"]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional, List, Callable, Dict, Any, Union
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class GPTChatGenerator:
|
||||
class OpenAIChatGenerator:
|
||||
"""
|
||||
Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo
|
||||
family of models accessed through the chat completions API endpoint.
|
||||
@ -27,12 +28,12 @@ class GPTChatGenerator:
|
||||
[documentation](https://platform.openai.com/docs/api-reference/chat).
|
||||
|
||||
```python
|
||||
from haystack.components.generators.chat import GPTChatGenerator
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
|
||||
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
|
||||
|
||||
client = GPTChatGenerator()
|
||||
client = OpenAIChatGenerator()
|
||||
response = client.run(messages)
|
||||
print(response)
|
||||
|
||||
@ -66,7 +67,7 @@ class GPTChatGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
@ -126,7 +127,7 @@ class GPTChatGenerator:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GPTChatGenerator":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
:param data: The dictionary representation of this component.
|
||||
@ -277,3 +278,29 @@ class GPTChatGenerator:
|
||||
logger.warning(
|
||||
"The completion for index %s has been truncated due to the content filter.", message.meta["index"]
|
||||
)
|
||||
|
||||
|
||||
class GPTChatGenerator(OpenAIChatGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
warnings.warn(
|
||||
"GPTChatGenerator is deprecated and will be removed in the next beta release. "
|
||||
"Please use OpenAIChatGenerator instead.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
streaming_callback=streaming_callback,
|
||||
api_base_url=api_base_url,
|
||||
organization=organization,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Optional, List, Callable, Dict, Any, Union
|
||||
|
||||
from openai import OpenAI, Stream
|
||||
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class GPTGenerator:
|
||||
class OpenAIGenerator:
|
||||
"""
|
||||
Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo
|
||||
family of models.
|
||||
@ -27,8 +28,8 @@ class GPTGenerator:
|
||||
[documentation](https://platform.openai.com/docs/api-reference/chat).
|
||||
|
||||
```python
|
||||
from haystack.components.generators import GPTGenerator
|
||||
client = GPTGenerator()
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
client = OpenAIGenerator()
|
||||
response = client.run("What's Natural Language Processing? Be brief.")
|
||||
print(response)
|
||||
|
||||
@ -59,7 +60,7 @@ class GPTGenerator:
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
@ -123,7 +124,7 @@ class GPTGenerator:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator":
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
:param data: The dictionary representation of this component.
|
||||
@ -279,3 +280,31 @@ class GPTGenerator:
|
||||
logger.warning(
|
||||
"The completion for index %s has been truncated due to the content filter.", message.meta["index"]
|
||||
)
|
||||
|
||||
|
||||
class GPTGenerator(OpenAIGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model_name: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
warnings.warn(
|
||||
"GPTGenerator is deprecated and will be removed in the next beta release. "
|
||||
"Please use OpenAIGenerator instead.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
model_name=model_name,
|
||||
streaming_callback=streaming_callback,
|
||||
api_base_url=api_base_url,
|
||||
organization=organization,
|
||||
system_prompt=system_prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
@ -8,7 +8,7 @@ from haystack import Pipeline
|
||||
from haystack.components.builders.answer_builder import AnswerBuilder
|
||||
from haystack.components.builders.prompt_builder import PromptBuilder
|
||||
from haystack.components.embedders import SentenceTransformersTextEmbedder
|
||||
from haystack.components.generators import GPTGenerator, HuggingFaceTGIGenerator
|
||||
from haystack.components.generators import OpenAIGenerator, HuggingFaceTGIGenerator
|
||||
from haystack.components.retrievers import InMemoryEmbeddingRetriever
|
||||
from haystack.dataclasses import Answer
|
||||
from haystack.document_stores import InMemoryDocumentStore, DocumentStore
|
||||
@ -179,13 +179,13 @@ class _GeneratorResolver(ABC):
|
||||
|
||||
class _OpenAIResolved(_GeneratorResolver):
|
||||
"""
|
||||
Resolves the OpenAI GPTGenerator.
|
||||
Resolves the OpenAIGenerator.
|
||||
"""
|
||||
|
||||
def resolve(self, model_key: str, api_key: str) -> Any:
|
||||
# does the model_key match the pattern OpenAI GPT pattern?
|
||||
if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key):
|
||||
return GPTGenerator(model_name=model_key, api_key=api_key)
|
||||
return OpenAIGenerator(model_name=model_key, api_key=api_key)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
deprecations:
|
||||
- |
|
||||
Deprecate GPTGenerator and GPTChatGenerator. Replace them with OpenAIGenerator and OpenAIChatGenerator.
|
||||
@ -3,7 +3,7 @@ import os
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack.components.generators.chat import GPTChatGenerator
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.generators.utils import default_streaming_callback
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
||||
|
||||
@ -16,9 +16,9 @@ def chat_messages():
|
||||
]
|
||||
|
||||
|
||||
class TestGPTChatGenerator:
|
||||
class TestOpenAIChatGenerator:
|
||||
def test_init_default(self):
|
||||
component = GPTChatGenerator(api_key="test-api-key")
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -27,10 +27,10 @@ class TestGPTChatGenerator:
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
GPTChatGenerator()
|
||||
OpenAIChatGenerator()
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = GPTChatGenerator(
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
@ -43,10 +43,10 @@ class TestGPTChatGenerator:
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = GPTChatGenerator(api_key="test-api-key")
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"organization": None,
|
||||
@ -57,7 +57,7 @@ class TestGPTChatGenerator:
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
component = GPTChatGenerator(
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
@ -66,7 +66,7 @@ class TestGPTChatGenerator:
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"organization": None,
|
||||
@ -77,7 +77,7 @@ class TestGPTChatGenerator:
|
||||
}
|
||||
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
component = GPTChatGenerator(
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
@ -86,7 +86,7 @@ class TestGPTChatGenerator:
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"organization": None,
|
||||
@ -98,7 +98,7 @@ class TestGPTChatGenerator:
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
@ -106,7 +106,7 @@ class TestGPTChatGenerator:
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
},
|
||||
}
|
||||
component = GPTChatGenerator.from_dict(data)
|
||||
component = OpenAIChatGenerator.from_dict(data)
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.api_base_url == "test-base-url"
|
||||
@ -115,7 +115,7 @@ class TestGPTChatGenerator:
|
||||
def test_from_dict_fail_wo_env_var(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.GPTChatGenerator",
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"organization": None,
|
||||
@ -125,10 +125,10 @@ class TestGPTChatGenerator:
|
||||
},
|
||||
}
|
||||
with pytest.raises(OpenAIError):
|
||||
GPTChatGenerator.from_dict(data)
|
||||
OpenAIChatGenerator.from_dict(data)
|
||||
|
||||
def test_run(self, chat_messages, mock_chat_completion):
|
||||
component = GPTChatGenerator()
|
||||
component = OpenAIChatGenerator()
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check that the component returns the correct ChatMessage response
|
||||
@ -139,7 +139,7 @@ class TestGPTChatGenerator:
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
|
||||
def test_run_with_params(self, chat_messages, mock_chat_completion):
|
||||
component = GPTChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check that the component calls the OpenAI API with the correct parameters
|
||||
@ -161,7 +161,7 @@ class TestGPTChatGenerator:
|
||||
nonlocal streaming_callback_called
|
||||
streaming_callback_called = True
|
||||
|
||||
component = GPTChatGenerator(streaming_callback=streaming_callback)
|
||||
component = OpenAIChatGenerator(streaming_callback=streaming_callback)
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check we called the streaming callback
|
||||
@ -176,7 +176,7 @@ class TestGPTChatGenerator:
|
||||
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
|
||||
|
||||
def test_check_abnormal_completions(self, caplog):
|
||||
component = GPTChatGenerator(api_key="test-api-key")
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
messages = [
|
||||
ChatMessage.from_assistant(
|
||||
"", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
|
||||
@ -208,7 +208,7 @@ class TestGPTChatGenerator:
|
||||
@pytest.mark.integration
|
||||
def test_live_run(self):
|
||||
chat_messages = [ChatMessage.from_user("What's the capital of France")]
|
||||
component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
|
||||
component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
|
||||
results = component.run(chat_messages)
|
||||
assert len(results["replies"]) == 1
|
||||
message: ChatMessage = results["replies"][0]
|
||||
@ -222,7 +222,9 @@ class TestGPTChatGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self, chat_messages):
|
||||
component = GPTChatGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIChatGenerator(
|
||||
model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")
|
||||
)
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run(chat_messages)
|
||||
|
||||
@ -242,7 +244,7 @@ class TestGPTChatGenerator:
|
||||
self.responses += chunk.content if chunk.content else ""
|
||||
|
||||
callback = Callback()
|
||||
component = GPTChatGenerator(streaming_callback=callback)
|
||||
component = OpenAIChatGenerator(streaming_callback=callback)
|
||||
results = component.run([ChatMessage.from_user("What's the capital of France?")])
|
||||
|
||||
assert len(results["replies"]) == 1
|
||||
|
||||
@ -4,14 +4,14 @@ from typing import List
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack.components.generators import GPTGenerator
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.generators.utils import default_streaming_callback
|
||||
from haystack.dataclasses import StreamingChunk, ChatMessage
|
||||
|
||||
|
||||
class TestGPTGenerator:
|
||||
class TestOpenAIGenerator:
|
||||
def test_init_default(self):
|
||||
component = GPTGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model_name == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -20,10 +20,10 @@ class TestGPTGenerator:
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
GPTGenerator()
|
||||
OpenAIGenerator()
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = GPTGenerator(
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
@ -36,10 +36,10 @@ class TestGPTGenerator:
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = GPTGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.GPTGenerator",
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"streaming_callback": None,
|
||||
@ -50,7 +50,7 @@ class TestGPTGenerator:
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
component = GPTGenerator(
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=default_streaming_callback,
|
||||
@ -59,7 +59,7 @@ class TestGPTGenerator:
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.GPTGenerator",
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"system_prompt": None,
|
||||
@ -70,7 +70,7 @@ class TestGPTGenerator:
|
||||
}
|
||||
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
component = GPTGenerator(
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model_name="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
@ -79,7 +79,7 @@ class TestGPTGenerator:
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.GPTGenerator",
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"system_prompt": None,
|
||||
@ -92,7 +92,7 @@ class TestGPTGenerator:
|
||||
def test_from_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.GPTGenerator",
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"system_prompt": None,
|
||||
@ -101,7 +101,7 @@ class TestGPTGenerator:
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
},
|
||||
}
|
||||
component = GPTGenerator.from_dict(data)
|
||||
component = OpenAIGenerator.from_dict(data)
|
||||
assert component.model_name == "gpt-4"
|
||||
assert component.streaming_callback is default_streaming_callback
|
||||
assert component.api_base_url == "test-base-url"
|
||||
@ -110,7 +110,7 @@ class TestGPTGenerator:
|
||||
def test_from_dict_fail_wo_env_var(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.GPTGenerator",
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"model_name": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
@ -119,10 +119,10 @@ class TestGPTGenerator:
|
||||
},
|
||||
}
|
||||
with pytest.raises(OpenAIError):
|
||||
GPTGenerator.from_dict(data)
|
||||
OpenAIGenerator.from_dict(data)
|
||||
|
||||
def test_run(self, mock_chat_completion):
|
||||
component = GPTGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
response = component.run("What's Natural Language Processing?")
|
||||
|
||||
# check that the component returns the correct ChatMessage response
|
||||
@ -139,7 +139,7 @@ class TestGPTGenerator:
|
||||
nonlocal streaming_callback_called
|
||||
streaming_callback_called = True
|
||||
|
||||
component = GPTGenerator(streaming_callback=streaming_callback)
|
||||
component = OpenAIGenerator(streaming_callback=streaming_callback)
|
||||
response = component.run("Come on, stream!")
|
||||
|
||||
# check we called the streaming callback
|
||||
@ -153,7 +153,7 @@ class TestGPTGenerator:
|
||||
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk
|
||||
|
||||
def test_run_with_params(self, mock_chat_completion):
|
||||
component = GPTGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
response = component.run("What's Natural Language Processing?")
|
||||
|
||||
# check that the component calls the OpenAI API with the correct parameters
|
||||
@ -169,7 +169,7 @@ class TestGPTGenerator:
|
||||
assert [isinstance(reply, str) for reply in response["replies"]]
|
||||
|
||||
def test_check_abnormal_completions(self, caplog):
|
||||
component = GPTGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
|
||||
# underlying implementation uses ChatMessage objects so we have to use them here
|
||||
messages: List[ChatMessage] = []
|
||||
@ -202,7 +202,7 @@ class TestGPTGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run(self):
|
||||
component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
results = component.run("What's the capital of France?")
|
||||
assert len(results["replies"]) == 1
|
||||
assert len(results["meta"]) == 1
|
||||
@ -224,7 +224,7 @@ class TestGPTGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self):
|
||||
component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run("Whatever")
|
||||
|
||||
@ -244,7 +244,7 @@ class TestGPTGenerator:
|
||||
self.responses += chunk.content if chunk.content else ""
|
||||
|
||||
callback = Callback()
|
||||
component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
|
||||
component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
|
||||
results = component.run("What's the capital of France?")
|
||||
|
||||
assert len(results["replies"]) == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user