From 96b9d3e32abdd4d7cf6280fea26570bf55158c30 Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Wed, 15 May 2024 10:00:38 +0200 Subject: [PATCH] fix: Adding missing `component` decorator to AzureOpenAIGenerator (#7698) * initial import * adding release notes * tests avoiding I/O operations * Update fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml --- haystack/components/generators/azure.py | 3 ++- haystack/components/generators/chat/azure.py | 3 ++- ...ure-generators-serialization-18fcdc9cbcb3732e.yaml | 4 ++++ test/components/generators/chat/test_azure.py | 10 ++++++++++ test/components/generators/test_azure.py | 11 +++++++++++ 5 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 388b9eb99..2c432d823 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators import OpenAIGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp logger = logging.getLogger(__name__) +@component class AzureOpenAIGenerator(OpenAIGenerator): """ A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text. diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 189b0c949..b6cd2e153 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional # pylint: disable=import-error from openai.lib.azure import AzureOpenAI -from haystack import default_from_dict, default_to_dict, logging +from haystack import component, default_from_dict, default_to_dict, logging from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import StreamingChunk from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable @@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp logger = logging.getLogger(__name__) +@component class AzureOpenAIChatGenerator(OpenAIChatGenerator): """ A Chat Generator component that uses the Azure OpenAI API to generate text. diff --git a/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml new file mode 100644 index 000000000..071969e62 --- /dev/null +++ b/releasenotes/notes/fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Azure generators components fixed, they were missing the `@component` decorator. diff --git a/test/components/generators/chat/test_azure.py b/test/components/generators/chat/test_azure.py index 4b92fdee5..c9693caac 100644 --- a/test/components/generators/chat/test_azure.py +++ b/test/components/generators/chat/test_azure.py @@ -6,6 +6,7 @@ import os import pytest from openai import OpenAIError +from haystack import Pipeline from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage @@ -80,6 +81,15 @@ class TestOpenAIChatGenerator: }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None), diff --git a/test/components/generators/test_azure.py b/test/components/generators/test_azure.py index b2373eb68..d5d52b6f2 100644 --- a/test/components/generators/test_azure.py +++ b/test/components/generators/test_azure.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import os + +from haystack import Pipeline from haystack.utils.auth import Secret import pytest @@ -83,6 +85,15 @@ class TestAzureOpenAIGenerator: }, } + def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch): + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint") + p = Pipeline() + p.add_component(instance=generator, name="generator") + p_str = p.dumps() + q = Pipeline.loads(p_str) + assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed." + @pytest.mark.integration @pytest.mark.skipif( not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),