fix: Adding missing component decorator to AzureOpenAIGenerator (#7698)

* initial import

* adding release notes

* tests avoiding I/O operations

* Update fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml
This commit is contained in:
David S. Batista 2024-05-15 10:00:38 +02:00 committed by GitHub
parent cc1d4b1c80
commit 96b9d3e32a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 29 additions and 2 deletions

View File

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional
# pylint: disable=import-error
from openai.lib.azure import AzureOpenAI
from haystack import default_from_dict, default_to_dict, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators import OpenAIGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp
logger = logging.getLogger(__name__)
@component
class AzureOpenAIGenerator(OpenAIGenerator):
"""
A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text.

View File

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional
# pylint: disable=import-error
from openai.lib.azure import AzureOpenAI
from haystack import default_from_dict, default_to_dict, logging
from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp
logger = logging.getLogger(__name__)
@component
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""
A Chat Generator component that uses the Azure OpenAI API to generate text.

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Azure generators components fixed, they were missing the `@component` decorator.

View File

@ -6,6 +6,7 @@ import os
import pytest
from openai import OpenAIError
from haystack import Pipeline
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
@ -80,6 +81,15 @@ class TestOpenAIChatGenerator:
},
}
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
p = Pipeline()
p.add_component(instance=generator, name="generator")
p_str = p.dumps()
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

View File

@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
from haystack import Pipeline
from haystack.utils.auth import Secret
import pytest
@ -83,6 +85,15 @@ class TestAzureOpenAIGenerator:
},
}
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
p = Pipeline()
p.add_component(instance=generator, name="generator")
p_str = p.dumps()
q = Pipeline.loads(p_str)
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed."
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),