mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-27 15:08:43 +00:00
fix: Adding missing component decorator to AzureOpenAIGenerator (#7698)
* initial import * adding release notes * tests avoiding I/O operations * Update fix-azure-generators-serialization-18fcdc9cbcb3732e.yaml
This commit is contained in:
parent
cc1d4b1c80
commit
96b9d3e32a
@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional
|
||||
# pylint: disable=import-error
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from haystack import default_from_dict, default_to_dict, logging
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
"""
|
||||
A Generator component that uses OpenAI's large language models (LLMs) on Azure to generate text.
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional
|
||||
# pylint: disable=import-error
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from haystack import default_from_dict, default_to_dict, logging
|
||||
from haystack import component, default_from_dict, default_to_dict, logging
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
|
||||
@ -16,6 +16,7 @@ from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inp
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@component
|
||||
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
"""
|
||||
A Chat Generator component that uses the Azure OpenAI API to generate text.
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Azure generators components fixed, they were missing the `@component` decorator.
|
||||
@ -6,6 +6,7 @@ import os
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.components.generators.chat import AzureOpenAIChatGenerator
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.dataclasses import ChatMessage
|
||||
@ -80,6 +81,15 @@ class TestOpenAIChatGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
generator = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
p = Pipeline()
|
||||
p.add_component(instance=generator, name="generator")
|
||||
p_str = p.dumps()
|
||||
q = Pipeline.loads(p_str)
|
||||
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization w/ AzureOpenAIChatGenerator failed."
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
|
||||
from haystack import Pipeline
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
@ -83,6 +85,15 @@ class TestAzureOpenAIGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_pipeline_serialization_deserialization(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
generator = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
p = Pipeline()
|
||||
p.add_component(instance=generator, name="generator")
|
||||
p_str = p.dumps()
|
||||
q = Pipeline.loads(p_str)
|
||||
assert p.to_dict() == q.to_dict(), "Pipeline serialization/deserialization with AzureOpenAIGenerator failed."
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user