From 506ab81d266b6cbb67302189d3d331bb4e95dac7 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 22 Dec 2023 19:37:29 +0100 Subject: [PATCH] chore: Rename GPT generators, deprecate old names (#6626) --- e2e/pipelines/test_rag_pipelines.py | 6 +-- examples/pipeline_loop_to_autocorrect_json.py | 6 +-- examples/pipelines/rag_pipeline.py | 6 +-- examples/rag/rag_self_correction.py | 4 +- examples/retrievers/in_memory_bm25_rag.py | 9 ++-- .../builders/dynamic_prompt_builder.py | 10 ++-- haystack/components/generators/__init__.py | 4 +- .../components/generators/chat/__init__.py | 4 +- haystack/components/generators/chat/openai.py | 37 +++++++++++++-- haystack/components/generators/openai.py | 39 ++++++++++++++-- haystack/pipeline_utils/rag.py | 6 +-- ...ename-gpt-generators-f25011d251fafd6d.yaml | 4 ++ .../components/generators/chat/test_openai.py | 46 ++++++++++--------- test/components/generators/test_openai.py | 44 +++++++++--------- 14 files changed, 144 insertions(+), 81 deletions(-) create mode 100644 releasenotes/notes/rename-gpt-generators-f25011d251fafd6d.yaml diff --git a/e2e/pipelines/test_rag_pipelines.py b/e2e/pipelines/test_rag_pipelines.py index db2cdd7f8..fd804d833 100644 --- a/e2e/pipelines/test_rag_pipelines.py +++ b/e2e/pipelines/test_rag_pipelines.py @@ -7,7 +7,7 @@ from haystack.document_stores import InMemoryDocumentStore from haystack.components.writers import DocumentWriter from haystack.components.retrievers import InMemoryBM25Retriever, InMemoryEmbeddingRetriever from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder -from haystack.components.generators import GPTGenerator +from haystack.components.generators import OpenAIGenerator from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder @@ -30,7 +30,7 @@ def test_bm25_rag_pipeline(tmp_path): rag_pipeline = Pipeline() rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") - rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") + rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("retriever", "prompt_builder.documents") rag_pipeline.connect("prompt_builder", "llm") @@ -102,7 +102,7 @@ def test_embedding_retrieval_rag_pipeline(tmp_path): instance=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()), name="retriever" ) rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") - rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") + rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("text_embedder", "retriever") rag_pipeline.connect("retriever", "prompt_builder.documents") diff --git a/examples/pipeline_loop_to_autocorrect_json.py b/examples/pipeline_loop_to_autocorrect_json.py index 4128e3a6d..054402921 100644 --- a/examples/pipeline_loop_to_autocorrect_json.py +++ b/examples/pipeline_loop_to_autocorrect_json.py @@ -2,7 +2,7 @@ import json import os from 
haystack import Pipeline -from haystack.components.generators.openai import GPTGenerator +from haystack.components.generators.openai import OpenAIGenerator from haystack.components.builders.prompt_builder import PromptBuilder import random from haystack import component @@ -83,8 +83,8 @@ prompt_template = """ # Let's build the pipeline (Make sure to set OPENAI_API_KEY as an environment variable) pipeline = Pipeline(max_loops_allowed=5) pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") -pipeline.add_component(instance=GPTGenerator(), name="llm") -pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser") +pipeline.add_component(instance=OpenAIGenerator(), name="llm") +pipeline.add_component(instance=OutputParser(pydantic_model=CitiesData), name="output_parser") # type: ignore pipeline.connect("prompt_builder", "llm") pipeline.connect("llm", "output_parser") diff --git a/examples/pipelines/rag_pipeline.py b/examples/pipelines/rag_pipeline.py index 1e9151c34..eebc88661 100644 --- a/examples/pipelines/rag_pipeline.py +++ b/examples/pipelines/rag_pipeline.py @@ -2,7 +2,7 @@ import os from haystack import Pipeline, Document from haystack.document_stores import InMemoryDocumentStore from haystack.components.retrievers import InMemoryBM25Retriever -from haystack.components.generators import GPTGenerator +from haystack.components.generators import OpenAIGenerator from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder @@ -20,7 +20,7 @@ documents = [ ] document_store.write_documents(documents) -# Build a RAG pipeline with a Retriever to get relevant documents to the query and a GPTGenerator interacting with LLMs using a custom prompt. +# Build a RAG pipeline with a Retriever to get relevant documents to the query and an OpenAIGenerator interacting with LLMs using a custom prompt.
prompt_template = """ Given these documents, answer the question.\nDocuments: {% for doc in documents %} @@ -33,7 +33,7 @@ Given these documents, answer the question.\nDocuments: rag_pipeline = Pipeline() rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="retriever") rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") -rag_pipeline.add_component(instance=GPTGenerator(api_key=OPENAI_API_KEY), name="llm") +rag_pipeline.add_component(instance=OpenAIGenerator(api_key=OPENAI_API_KEY), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("retriever", "prompt_builder.documents") rag_pipeline.connect("prompt_builder", "llm") diff --git a/examples/rag/rag_self_correction.py b/examples/rag/rag_self_correction.py index a3455dd12..83046ec4f 100644 --- a/examples/rag/rag_self_correction.py +++ b/examples/rag/rag_self_correction.py @@ -7,7 +7,7 @@ from canals.component.types import Variadic from haystack import Pipeline, Document, component, default_to_dict, default_from_dict, DeserializationError from haystack.document_stores import InMemoryDocumentStore from haystack.components.retrievers import InMemoryBM25Retriever -from haystack.components.generators import GPTGenerator +from haystack.components.generators import OpenAIGenerator from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.others import Multiplexer from haystack.components.routers.conditional_router import ConditionalRouter @@ -99,7 +99,7 @@ def self_correcting_pipeline(): ), name="prompt_builder", ) - rag_pipeline.add_component(instance=GPTGenerator(), name="llm") + rag_pipeline.add_component(instance=OpenAIGenerator(), name="llm") rag_pipeline.add_component( instance=ConditionalRouter( routes=[ diff --git a/examples/retrievers/in_memory_bm25_rag.py b/examples/retrievers/in_memory_bm25_rag.py index b0a0b33a5..75238270c 100644 --- a/examples/retrievers/in_memory_bm25_rag.py +++ b/examples/retrievers/in_memory_bm25_rag.py @@ -1,10 +1,11 @@ import os +from pathlib import Path from haystack import Document from haystack import Pipeline from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder -from haystack.components.generators import GPTGenerator +from haystack.components.generators import OpenAIGenerator from haystack.components.retrievers import InMemoryBM25Retriever from haystack.document_stores import InMemoryDocumentStore @@ -22,7 +23,7 @@ prompt_template = """ rag_pipeline = Pipeline() rag_pipeline.add_component(instance=InMemoryBM25Retriever(document_store=InMemoryDocumentStore()), name="retriever") rag_pipeline.add_component(instance=PromptBuilder(template=prompt_template), name="prompt_builder") -rag_pipeline.add_component(instance=GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") +rag_pipeline.add_component(instance=OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")), name="llm") rag_pipeline.add_component(instance=AnswerBuilder(), name="answer_builder") rag_pipeline.connect("retriever", "prompt_builder.documents") rag_pipeline.connect("prompt_builder", "llm") @@ -31,7 +32,7 @@ rag_pipeline.connect("llm.metadata", "answer_builder.metadata") rag_pipeline.connect("retriever", "answer_builder.documents") # Draw the pipeline -rag_pipeline.draw("./rag_pipeline.png") +rag_pipeline.draw(Path("./rag_pipeline.png")) # Add Documents documents = [ @@ 
-43,7 +44,7 @@ documents = [ content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves." ), ] -rag_pipeline.get_component("retriever").document_store.write_documents(documents) +rag_pipeline.get_component("retriever").document_store.write_documents(documents) # type: ignore # Run the pipeline question = "How many languages are there?" diff --git a/haystack/components/builders/dynamic_prompt_builder.py b/haystack/components/builders/dynamic_prompt_builder.py index 59b025009..9ab025eb8 100644 --- a/haystack/components/builders/dynamic_prompt_builder.py +++ b/haystack/components/builders/dynamic_prompt_builder.py @@ -30,13 +30,13 @@ class DynamicPromptBuilder: ```python from haystack.components.builders import DynamicPromptBuilder - from haystack.components.generators.chat import GPTChatGenerator + from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage from haystack import Pipeline # no parameter init, we don't use any runtime template variables prompt_builder = DynamicPromptBuilder() - llm = GPTChatGenerator(api_key="", model_name="gpt-3.5-turbo") + llm = OpenAIChatGenerator(api_key="", model_name="gpt-3.5-turbo") pipe = Pipeline() pipe.add_component("prompt_builder", prompt_builder) @@ -88,14 +88,14 @@ class DynamicPromptBuilder: ```python from haystack.components.builders import DynamicPromptBuilder - from haystack.components.generators.chat import GPTChatGenerator + from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, Document from haystack import Pipeline, component from typing import List # we'll use documents runtime variable in our template, so we need to specify it in the init prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"]) - llm = GPTChatGenerator(api_key="", model_name="gpt-3.5-turbo") + llm = OpenAIChatGenerator(api_key="", model_name="gpt-3.5-turbo") @component @@ -135,7 +135,7 @@ class DynamicPromptBuilder: ```python prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"], chat_mode=False) - llm = GPTGenerator(api_key="", model_name="gpt-3.5-turbo") + llm = OpenAIGenerator(api_key="", model_name="gpt-3.5-turbo") @component diff --git a/haystack/components/generators/__init__.py b/haystack/components/generators/__init__.py index d018ff63a..a9251b933 100644 --- a/haystack/components/generators/__init__.py +++ b/haystack/components/generators/__init__.py @@ -1,5 +1,5 @@ from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator -from haystack.components.generators.openai import GPTGenerator +from haystack.components.generators.openai import OpenAIGenerator, GPTGenerator -__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "GPTGenerator"] +__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "GPTGenerator"] diff --git a/haystack/components/generators/chat/__init__.py b/haystack/components/generators/chat/__init__.py index b28648b5c..3227e50bf 100644 --- a/haystack/components/generators/chat/__init__.py +++ b/haystack/components/generators/chat/__init__.py @@ -1,4 +1,4 @@ from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator -from haystack.components.generators.chat.openai import GPTChatGenerator +from haystack.components.generators.chat.openai import 
OpenAIChatGenerator, GPTChatGenerator -__all__ = ["HuggingFaceTGIChatGenerator", "GPTChatGenerator"] +__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator"] diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index eb4e53118..30dde2967 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -1,6 +1,7 @@ import dataclasses import json import logging +import warnings from typing import Optional, List, Callable, Dict, Any, Union from openai import OpenAI, Stream @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) @component -class GPTChatGenerator: +class OpenAIChatGenerator: """ Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo family of models accessed through the chat completions API endpoint. @@ -27,12 +28,12 @@ class GPTChatGenerator: [documentation](https://platform.openai.com/docs/api-reference/chat). ```python - from haystack.components.generators.chat import GPTChatGenerator + from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage messages = [ChatMessage.from_user("What's Natural Language Processing?")] - client = GPTChatGenerator() + client = OpenAIChatGenerator() response = client.run(messages) print(response) @@ -66,7 +67,7 @@ class GPTChatGenerator: generation_kwargs: Optional[Dict[str, Any]] = None, ): """ - Creates an instance of ChatGPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's + Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's GPT-3.5 model. :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the @@ -126,7 +127,7 @@ class GPTChatGenerator: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GPTChatGenerator": + def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator": """ Deserialize this component from a dictionary. :param data: The dictionary representation of this component. @@ -277,3 +278,29 @@ class GPTChatGenerator: logger.warning( "The completion for index %s has been truncated due to the content filter.", message.meta["index"] ) + + +class GPTChatGenerator(OpenAIChatGenerator): + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + api_base_url: Optional[str] = None, + organization: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + warnings.warn( + "GPTChatGenerator is deprecated and will be removed in the next beta release. 
" + "Please use OpenAIChatGenerator instead.", + UserWarning, + stacklevel=2, + ) + super().__init__( + api_key=api_key, + model_name=model_name, + streaming_callback=streaming_callback, + api_base_url=api_base_url, + organization=organization, + generation_kwargs=generation_kwargs, + ) diff --git a/haystack/components/generators/openai.py b/haystack/components/generators/openai.py index 53a7d658c..8fe454ac3 100644 --- a/haystack/components/generators/openai.py +++ b/haystack/components/generators/openai.py @@ -1,6 +1,7 @@ import dataclasses import json import logging +import warnings from typing import Optional, List, Callable, Dict, Any, Union from openai import OpenAI, Stream @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) @component -class GPTGenerator: +class OpenAIGenerator: """ Enables text generation using OpenAI's large language models (LLMs). It supports gpt-4 and gpt-3.5-turbo family of models. @@ -27,8 +28,8 @@ class GPTGenerator: [documentation](https://platform.openai.com/docs/api-reference/chat). ```python - from haystack.components.generators import GPTGenerator - client = GPTGenerator() + from haystack.components.generators import OpenAIGenerator + client = OpenAIGenerator() response = client.run("What's Natural Language Processing? Be brief.") print(response) @@ -59,7 +60,7 @@ class GPTGenerator: generation_kwargs: Optional[Dict[str, Any]] = None, ): """ - Creates an instance of GPTGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's + Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model_name`, this is for OpenAI's GPT-3.5 model. :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the @@ -123,7 +124,7 @@ class GPTGenerator: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "GPTGenerator": + def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator": """ Deserialize this component from a dictionary. :param data: The dictionary representation of this component. @@ -279,3 +280,31 @@ class GPTGenerator: logger.warning( "The completion for index %s has been truncated due to the content filter.", message.meta["index"] ) + + +class GPTGenerator(OpenAIGenerator): + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + api_base_url: Optional[str] = None, + organization: Optional[str] = None, + system_prompt: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + warnings.warn( + "GPTGenerator is deprecated and will be removed in the next beta release. 
" + "Please use OpenAIGenerator instead.", + UserWarning, + stacklevel=2, + ) + super().__init__( + api_key=api_key, + model_name=model_name, + streaming_callback=streaming_callback, + api_base_url=api_base_url, + organization=organization, + system_prompt=system_prompt, + generation_kwargs=generation_kwargs, + ) diff --git a/haystack/pipeline_utils/rag.py b/haystack/pipeline_utils/rag.py index 1c088efe8..34e21aba3 100644 --- a/haystack/pipeline_utils/rag.py +++ b/haystack/pipeline_utils/rag.py @@ -8,7 +8,7 @@ from haystack import Pipeline from haystack.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.embedders import SentenceTransformersTextEmbedder -from haystack.components.generators import GPTGenerator, HuggingFaceTGIGenerator +from haystack.components.generators import OpenAIGenerator, HuggingFaceTGIGenerator from haystack.components.retrievers import InMemoryEmbeddingRetriever from haystack.dataclasses import Answer from haystack.document_stores import InMemoryDocumentStore, DocumentStore @@ -179,13 +179,13 @@ class _GeneratorResolver(ABC): class _OpenAIResolved(_GeneratorResolver): """ - Resolves the OpenAI GPTGenerator. + Resolves the OpenAIGenerator. """ def resolve(self, model_key: str, api_key: str) -> Any: # does the model_key match the pattern OpenAI GPT pattern? if re.match(r"^gpt-4-.*", model_key) or re.match(r"^gpt-3.5-.*", model_key): - return GPTGenerator(model_name=model_key, api_key=api_key) + return OpenAIGenerator(model_name=model_key, api_key=api_key) return None diff --git a/releasenotes/notes/rename-gpt-generators-f25011d251fafd6d.yaml b/releasenotes/notes/rename-gpt-generators-f25011d251fafd6d.yaml new file mode 100644 index 000000000..72f40559f --- /dev/null +++ b/releasenotes/notes/rename-gpt-generators-f25011d251fafd6d.yaml @@ -0,0 +1,4 @@ +--- +deprecations: + - | + Deprecate GPTGenerator and GPTChatGenerator. Replace them with OpenAIGenerator and OpenAIChatGenerator. 
diff --git a/test/components/generators/chat/test_openai.py b/test/components/generators/chat/test_openai.py index 03f36fd45..39668d086 100644 --- a/test/components/generators/chat/test_openai.py +++ b/test/components/generators/chat/test_openai.py @@ -3,7 +3,7 @@ import os import pytest from openai import OpenAIError -from haystack.components.generators.chat import GPTChatGenerator +from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.generators.utils import default_streaming_callback from haystack.dataclasses import ChatMessage, StreamingChunk @@ -16,9 +16,9 @@ def chat_messages(): ] -class TestGPTChatGenerator: +class TestOpenAIChatGenerator: def test_init_default(self): - component = GPTChatGenerator(api_key="test-api-key") + component = OpenAIChatGenerator(api_key="test-api-key") assert component.client.api_key == "test-api-key" assert component.model_name == "gpt-3.5-turbo" assert component.streaming_callback is None @@ -27,10 +27,10 @@ class TestGPTChatGenerator: def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) with pytest.raises(OpenAIError): - GPTChatGenerator() + OpenAIChatGenerator() def test_init_with_parameters(self): - component = GPTChatGenerator( + component = OpenAIChatGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=default_streaming_callback, @@ -43,10 +43,10 @@ class TestGPTChatGenerator: assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self): - component = GPTChatGenerator(api_key="test-api-key") + component = OpenAIChatGenerator(api_key="test-api-key") data = component.to_dict() assert data == { - "type": "haystack.components.generators.chat.openai.GPTChatGenerator", + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "model_name": "gpt-3.5-turbo", "organization": None, @@ -57,7 +57,7 @@ class TestGPTChatGenerator: } def test_to_dict_with_parameters(self): - component = GPTChatGenerator( + component = OpenAIChatGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=default_streaming_callback, @@ -66,7 +66,7 @@ class TestGPTChatGenerator: ) data = component.to_dict() assert data == { - "type": "haystack.components.generators.chat.openai.GPTChatGenerator", + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "model_name": "gpt-4", "organization": None, @@ -77,7 +77,7 @@ class TestGPTChatGenerator: } def test_to_dict_with_lambda_streaming_callback(self): - component = GPTChatGenerator( + component = OpenAIChatGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=lambda x: x, @@ -86,7 +86,7 @@ class TestGPTChatGenerator: ) data = component.to_dict() assert data == { - "type": "haystack.components.generators.chat.openai.GPTChatGenerator", + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "model_name": "gpt-4", "organization": None, @@ -98,7 +98,7 @@ class TestGPTChatGenerator: def test_from_dict(self): data = { - "type": "haystack.components.generators.chat.openai.GPTChatGenerator", + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "model_name": "gpt-4", "api_base_url": "test-base-url", @@ -106,7 +106,7 @@ class TestGPTChatGenerator: "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - component = GPTChatGenerator.from_dict(data) + component = 
OpenAIChatGenerator.from_dict(data) assert component.model_name == "gpt-4" assert component.streaming_callback is default_streaming_callback assert component.api_base_url == "test-base-url" @@ -115,7 +115,7 @@ class TestGPTChatGenerator: def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) data = { - "type": "haystack.components.generators.chat.openai.GPTChatGenerator", + "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", "init_parameters": { "model_name": "gpt-4", "organization": None, @@ -125,10 +125,10 @@ class TestGPTChatGenerator: }, } with pytest.raises(OpenAIError): - GPTChatGenerator.from_dict(data) + OpenAIChatGenerator.from_dict(data) def test_run(self, chat_messages, mock_chat_completion): - component = GPTChatGenerator() + component = OpenAIChatGenerator() response = component.run(chat_messages) # check that the component returns the correct ChatMessage response @@ -139,7 +139,7 @@ class TestGPTChatGenerator: assert [isinstance(reply, ChatMessage) for reply in response["replies"]] def test_run_with_params(self, chat_messages, mock_chat_completion): - component = GPTChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) + component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) response = component.run(chat_messages) # check that the component calls the OpenAI API with the correct parameters @@ -161,7 +161,7 @@ class TestGPTChatGenerator: nonlocal streaming_callback_called streaming_callback_called = True - component = GPTChatGenerator(streaming_callback=streaming_callback) + component = OpenAIChatGenerator(streaming_callback=streaming_callback) response = component.run(chat_messages) # check we called the streaming callback @@ -176,7 +176,7 @@ class TestGPTChatGenerator: assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk def test_check_abnormal_completions(self, caplog): - component = GPTChatGenerator(api_key="test-api-key") + component = OpenAIChatGenerator(api_key="test-api-key") messages = [ ChatMessage.from_assistant( "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} @@ -208,7 +208,7 @@ class TestGPTChatGenerator: @pytest.mark.integration def test_live_run(self): chat_messages = [ChatMessage.from_user("What's the capital of France")] - component = GPTChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1}) + component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1}) results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -222,7 +222,9 @@ class TestGPTChatGenerator: ) @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): - component = GPTChatGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) + component = OpenAIChatGenerator( + model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY") + ) with pytest.raises(OpenAIError): component.run(chat_messages) @@ -242,7 +244,7 @@ class TestGPTChatGenerator: self.responses += chunk.content if chunk.content else "" callback = Callback() - component = GPTChatGenerator(streaming_callback=callback) + component = OpenAIChatGenerator(streaming_callback=callback) results = component.run([ChatMessage.from_user("What's the capital of France?")]) assert len(results["replies"]) == 1 diff --git a/test/components/generators/test_openai.py 
b/test/components/generators/test_openai.py index 755d7c867..7832cfb5e 100644 --- a/test/components/generators/test_openai.py +++ b/test/components/generators/test_openai.py @@ -4,14 +4,14 @@ from typing import List import pytest from openai import OpenAIError -from haystack.components.generators import GPTGenerator +from haystack.components.generators import OpenAIGenerator from haystack.components.generators.utils import default_streaming_callback from haystack.dataclasses import StreamingChunk, ChatMessage -class TestGPTGenerator: +class TestOpenAIGenerator: def test_init_default(self): - component = GPTGenerator(api_key="test-api-key") + component = OpenAIGenerator(api_key="test-api-key") assert component.client.api_key == "test-api-key" assert component.model_name == "gpt-3.5-turbo" assert component.streaming_callback is None @@ -20,10 +20,10 @@ class TestGPTGenerator: def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) with pytest.raises(OpenAIError): - GPTGenerator() + OpenAIGenerator() def test_init_with_parameters(self): - component = GPTGenerator( + component = OpenAIGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=default_streaming_callback, @@ -36,10 +36,10 @@ class TestGPTGenerator: assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self): - component = GPTGenerator(api_key="test-api-key") + component = OpenAIGenerator(api_key="test-api-key") data = component.to_dict() assert data == { - "type": "haystack.components.generators.openai.GPTGenerator", + "type": "haystack.components.generators.openai.OpenAIGenerator", "init_parameters": { "model_name": "gpt-3.5-turbo", "streaming_callback": None, @@ -50,7 +50,7 @@ class TestGPTGenerator: } def test_to_dict_with_parameters(self): - component = GPTGenerator( + component = OpenAIGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=default_streaming_callback, @@ -59,7 +59,7 @@ class TestGPTGenerator: ) data = component.to_dict() assert data == { - "type": "haystack.components.generators.openai.GPTGenerator", + "type": "haystack.components.generators.openai.OpenAIGenerator", "init_parameters": { "model_name": "gpt-4", "system_prompt": None, @@ -70,7 +70,7 @@ class TestGPTGenerator: } def test_to_dict_with_lambda_streaming_callback(self): - component = GPTGenerator( + component = OpenAIGenerator( api_key="test-api-key", model_name="gpt-4", streaming_callback=lambda x: x, @@ -79,7 +79,7 @@ class TestGPTGenerator: ) data = component.to_dict() assert data == { - "type": "haystack.components.generators.openai.GPTGenerator", + "type": "haystack.components.generators.openai.OpenAIGenerator", "init_parameters": { "model_name": "gpt-4", "system_prompt": None, @@ -92,7 +92,7 @@ class TestGPTGenerator: def test_from_dict(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key") data = { - "type": "haystack.components.generators.openai.GPTGenerator", + "type": "haystack.components.generators.openai.OpenAIGenerator", "init_parameters": { "model_name": "gpt-4", "system_prompt": None, @@ -101,7 +101,7 @@ class TestGPTGenerator: "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, } - component = GPTGenerator.from_dict(data) + component = OpenAIGenerator.from_dict(data) assert component.model_name == "gpt-4" assert component.streaming_callback is default_streaming_callback assert component.api_base_url == "test-base-url" @@ -110,7 +110,7 @@ class 
TestGPTGenerator: def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("OPENAI_API_KEY", raising=False) data = { - "type": "haystack.components.generators.openai.GPTGenerator", + "type": "haystack.components.generators.openai.OpenAIGenerator", "init_parameters": { "model_name": "gpt-4", "api_base_url": "test-base-url", @@ -119,10 +119,10 @@ class TestGPTGenerator: }, } with pytest.raises(OpenAIError): - GPTGenerator.from_dict(data) + OpenAIGenerator.from_dict(data) def test_run(self, mock_chat_completion): - component = GPTGenerator(api_key="test-api-key") + component = OpenAIGenerator(api_key="test-api-key") response = component.run("What's Natural Language Processing?") # check that the component returns the correct ChatMessage response @@ -139,7 +139,7 @@ class TestGPTGenerator: nonlocal streaming_callback_called streaming_callback_called = True - component = GPTGenerator(streaming_callback=streaming_callback) + component = OpenAIGenerator(streaming_callback=streaming_callback) response = component.run("Come on, stream!") # check we called the streaming callback @@ -153,7 +153,7 @@ class TestGPTGenerator: assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk def test_run_with_params(self, mock_chat_completion): - component = GPTGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5}) + component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5}) response = component.run("What's Natural Language Processing?") # check that the component calls the OpenAI API with the correct parameters @@ -169,7 +169,7 @@ class TestGPTGenerator: assert [isinstance(reply, str) for reply in response["replies"]] def test_check_abnormal_completions(self, caplog): - component = GPTGenerator(api_key="test-api-key") + component = OpenAIGenerator(api_key="test-api-key") # underlying implementation uses ChatMessage objects so we have to use them here messages: List[ChatMessage] = [] @@ -202,7 +202,7 @@ class TestGPTGenerator: ) @pytest.mark.integration def test_live_run(self): - component = GPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")) + component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY")) results = component.run("What's the capital of France?") assert len(results["replies"]) == 1 assert len(results["meta"]) == 1 @@ -224,7 +224,7 @@ class TestGPTGenerator: ) @pytest.mark.integration def test_live_run_wrong_model(self): - component = GPTGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) + component = OpenAIGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) with pytest.raises(OpenAIError): component.run("Whatever") @@ -244,7 +244,7 @@ class TestGPTGenerator: self.responses += chunk.content if chunk.content else "" callback = Callback() - component = GPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback) + component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback) results = component.run("What's the capital of France?") assert len(results["replies"]) == 1
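For downstream users, migration is a pure rename: the constructor arguments and `run()` signatures shown throughout this patch are unchanged. A before/after sketch of the imports (it assumes this beta of Haystack and a valid `OPENAI_API_KEY` in the environment to actually run):

```python
import os

# Before (still importable, but now emits a UserWarning at construction):
#   from haystack.components.generators import GPTGenerator
#   from haystack.components.generators.chat import GPTChatGenerator

# After:
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

llm = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
chat_llm = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"))

print(llm.run("What's Natural Language Processing? Be brief."))
print(chat_llm.run([ChatMessage.from_user("What's Natural Language Processing?")]))
```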