Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-20 04:18:57 +00:00)

feat: Add AzureOpenAIGenerator and AzureOpenAIChatGenerator (#6648)

* Add AzureOpenAIGenerator and AzureOpenAIChatGenerator

parent 9c08f3d9c7
commit b7159ad7c2
haystack/components/generators/__init__.py
@@ -1,5 +1,12 @@
 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
 from haystack.components.generators.hugging_face_tgi import HuggingFaceTGIGenerator
 from haystack.components.generators.openai import OpenAIGenerator, GPTGenerator
+from haystack.components.generators.azure import AzureOpenAIGenerator

-__all__ = ["HuggingFaceLocalGenerator", "HuggingFaceTGIGenerator", "OpenAIGenerator", "GPTGenerator"]
+__all__ = [
+    "HuggingFaceLocalGenerator",
+    "HuggingFaceTGIGenerator",
+    "OpenAIGenerator",
+    "GPTGenerator",
+    "AzureOpenAIGenerator",
+]
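For readers skimming the diff: the net effect of this hunk is a new package-level re-export. A quick sketch, not part of the commit, confirming both import paths resolve to the same class:

```python
# Hypothetical check, not in the commit: the module path and the package-level
# re-export name the same class object.
from haystack.components.generators.azure import AzureOpenAIGenerator
from haystack.components.generators import AzureOpenAIGenerator as ReExported

assert AzureOpenAIGenerator is ReExported
```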
haystack/components/generators/azure.py (new file, 158 lines)
@@ -0,0 +1,158 @@
import logging
import os
from typing import Optional, Callable, Dict, Any

# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import default_to_dict, default_from_dict
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


class AzureOpenAIGenerator(OpenAIGenerator):
    """
    Enables text generation using OpenAI's large language models (LLMs) on Azure. It supports the gpt-4 and
    gpt-3.5-turbo family of models.

    Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter of the `run` method.

    For more details on OpenAI models deployed on Azure, refer to the Microsoft
    [documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/).

    ```python
    from haystack.components.generators import AzureOpenAIGenerator
    client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
                                  api_key="<your api key>",
                                  azure_deployment="<a model deployment name, e.g. gpt-35-turbo>")
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)

    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model':
    >> 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with the gpt-4 and gpt-3.5-turbo family of models.
    - **Streaming Support**: Supports streaming responses from the OpenAI API.
    - **Customizability**: Supports all parameters supported by the OpenAI API.

    Input and Output Format:
    - **String Format**: This component uses strings for both input and output.
    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-35-turbo",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param system_prompt: The prompt to use for the system. If not provided, no system prompt is used.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to
        # interact with the API.

        # Why is this here?
        # AzureOpenAI's init forces us to pass either base_url or azure_endpoint as a non-None init parameter.
        # This way we accommodate the use case where the env var AZURE_OPENAI_ENDPOINT is set instead of
        # passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.generation_kwargs = generation_kwargs or {}
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model_name: str = azure_deployment or "gpt-35-turbo"

        self.client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            system_prompt=self.system_prompt,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIGenerator":
        """
        Deserialize this component from a dictionary.
        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)
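Worth noting about `to_dict`/`from_dict` above: credentials (`api_key`, `azure_ad_token`, `azure_ad_token_provider`) are excluded from the serialized form. A minimal round-trip sketch, assuming placeholder values and no real Azure resource:

```python
import os

from haystack.components.generators import AzureOpenAIGenerator

# Placeholder credentials; constructing the component makes no network call.
# AzureOpenAI also reads the key from this env var when api_key is not passed.
os.environ["AZURE_OPENAI_API_KEY"] = "test-api-key"

gen = AzureOpenAIGenerator(
    azure_endpoint="https://example-resource.azure.openai.com/",
    azure_deployment="gpt-35-turbo",
)

data = gen.to_dict()
# The key is deliberately absent from the serialized dict, so secrets never
# land in saved pipeline definitions; from_dict re-reads the env var instead.
assert "api_key" not in data["init_parameters"]

restored = AzureOpenAIGenerator.from_dict(data)
assert restored.azure_deployment == "gpt-35-turbo"
```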
haystack/components/generators/chat/__init__.py
@@ -1,4 +1,5 @@
 from haystack.components.generators.chat.hugging_face_tgi import HuggingFaceTGIChatGenerator
 from haystack.components.generators.chat.openai import OpenAIChatGenerator, GPTChatGenerator
+from haystack.components.generators.chat.azure import AzureOpenAIChatGenerator

-__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator"]
+__all__ = ["HuggingFaceTGIChatGenerator", "OpenAIChatGenerator", "GPTChatGenerator", "AzureOpenAIChatGenerator"]
haystack/components/generators/chat/azure.py (new file, 161 lines)
@@ -0,0 +1,161 @@
import logging
import os
from typing import Optional, Callable, Dict, Any

# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import default_to_dict, default_from_dict
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk

logger = logging.getLogger(__name__)


class AzureOpenAIChatGenerator(OpenAIChatGenerator):
    """
    Enables text generation using OpenAI's large language models (LLMs) on Azure. It supports the gpt-4 and
    gpt-3.5-turbo family of models accessed through the chat completions API endpoint.

    Users can pass any text generation parameters valid for the `openai.ChatCompletion.create` method
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter of the `run` method.

    For more details on OpenAI models deployed on Azure, refer to the Microsoft
    [documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/).

    ```python
    from haystack.components.generators.chat import AzureOpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = AzureOpenAIChatGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
                                      api_key="<your api key>",
                                      azure_deployment="<a model deployment name, e.g. gpt-35-turbo>")
    response = client.run(messages)
    print(response)

    >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence
    >> that focuses on enabling computers to understand, interpret, and generate human language in a way that is
    >> meaningful and useful.', role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
    >> meta={'model': 'gpt-3.5-turbo-0613', 'index': 0, 'finish_reason': 'stop',
    >> 'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]}
    ```

    Key Features and Compatibility:
    - **Primary Compatibility**: Designed to work seamlessly with the OpenAI API Chat Completion endpoint
        and the gpt-4 and gpt-3.5-turbo family of models.
    - **Streaming Support**: Supports streaming responses from the OpenAI API Chat Completion endpoint.
    - **Customizability**: Supports all parameters supported by the OpenAI API Chat Completion endpoint.

    Input and Output Format:
    - **ChatMessage Format**: This component uses the ChatMessage format for structuring both input and output,
        ensuring coherent and contextually relevant responses in chat-based text generation scenarios. Details on the
        ChatMessage format can be found at: https://github.com/openai/openai-python/blob/main/chatml.md.
    """

    # pylint: disable=super-init-not-called
    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-35-turbo",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to
        # interact with the API.

        # Why is this here?
        # AzureOpenAI's init forces us to pass either base_url or azure_endpoint as a non-None init parameter.
        # This way we accommodate the use case where the env var AZURE_OPENAI_ENDPOINT is set instead of
        # passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model_name = azure_deployment or "gpt-35-turbo"

        self.client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIChatGenerator":
        """
        Deserialize this component from a dictionary.
        :param data: The dictionary representation of this component.
        :return: The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler)
        return default_from_dict(cls, data)
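The chat variant inherits `run` and streaming from `OpenAIChatGenerator`. A sketch of token-by-token streaming, assuming `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_API_KEY` are exported and the default `gpt-35-turbo` deployment exists:

```python
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk


def print_chunk(chunk: StreamingChunk) -> None:
    # Invoked once per streamed delta as tokens arrive from Azure.
    print(chunk.content, end="", flush=True)


# Endpoint and key are read from the environment here (see __init__ above).
client = AzureOpenAIChatGenerator(streaming_callback=print_chunk)
result = client.run([ChatMessage.from_user("What's Natural Language Processing? Be brief.")])
print()
print(result["replies"][0].meta["finish_reason"])  # typically "stop"
```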
Release note (new file)
@@ -0,0 +1,4 @@
---
features:
  - |
    Adds support for Azure OpenAI models with AzureOpenAIGenerator and AzureOpenAIChatGenerator components.
test/components/generators/chat/test_azure.py (new file, 89 lines)
@@ -0,0 +1,89 @@
import os

import pytest
from openai import OpenAIError

from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import default_streaming_callback
from haystack.dataclasses import ChatMessage


class TestAzureOpenAIChatGenerator:
    def test_init_default(self):
        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", api_key="test-api-key")
        assert component.client.api_key == "test-api-key"
        assert component.azure_deployment == "gpt-35-turbo"
        assert component.streaming_callback is None
        assert not component.generation_kwargs

    def test_init_fail_wo_api_key(self, monkeypatch):
        monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
        with pytest.raises(OpenAIError):
            AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")

    def test_init_with_parameters(self):
        component = AzureOpenAIChatGenerator(
            azure_endpoint="some-non-existing-endpoint",
            api_key="test-api-key",
            streaming_callback=default_streaming_callback,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        assert component.client.api_key == "test-api-key"
        assert component.azure_deployment == "gpt-35-turbo"
        assert component.streaming_callback is default_streaming_callback
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_to_dict_default(self):
        component = AzureOpenAIChatGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
            "init_parameters": {
                "api_version": "2023-05-15",
                "azure_endpoint": "some-non-existing-endpoint",
                "azure_deployment": "gpt-35-turbo",
                "organization": None,
                "streaming_callback": None,
                "generation_kwargs": {},
            },
        }

    def test_to_dict_with_parameters(self):
        component = AzureOpenAIChatGenerator(
            api_key="test-api-key",
            azure_endpoint="some-non-existing-endpoint",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
            "init_parameters": {
                "api_version": "2023-05-15",
                "azure_endpoint": "some-non-existing-endpoint",
                "azure_deployment": "gpt-35-turbo",
                "organization": None,
                "streaming_callback": None,
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
            },
        }

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
        reason=(
            "Please export env variables called AZURE_OPENAI_API_KEY containing "
            "the Azure OpenAI key and AZURE_OPENAI_ENDPOINT containing "
            "the Azure OpenAI endpoint URL to run this test."
        ),
    )
    def test_live_run(self):
        chat_messages = [ChatMessage.from_user("What's the capital of France")]
        component = AzureOpenAIChatGenerator(organization="HaystackCI")
        results = component.run(chat_messages)
        assert len(results["replies"]) == 1
        message: ChatMessage = results["replies"][0]
        assert "Paris" in message.content
        assert "gpt-35-turbo" in message.meta["model"]
        assert message.meta["finish_reason"] == "stop"

    # additional tests intentionally omitted as they are covered by test_openai.py
test/components/generators/test_azure.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import os

import pytest
from openai import OpenAIError

from haystack.components.generators import AzureOpenAIGenerator
from haystack.components.generators.utils import default_streaming_callback


class TestAzureOpenAIGenerator:
    def test_init_default(self):
        component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
        assert component.client.api_key == "test-api-key"
        assert component.azure_deployment == "gpt-35-turbo"
        assert component.streaming_callback is None
        assert not component.generation_kwargs

    def test_init_fail_wo_api_key(self, monkeypatch):
        monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
        with pytest.raises(OpenAIError):
            AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")

    def test_init_with_parameters(self):
        component = AzureOpenAIGenerator(
            api_key="test-api-key",
            azure_endpoint="some-non-existing-endpoint",
            azure_deployment="gpt-35-turbo",
            streaming_callback=default_streaming_callback,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        assert component.client.api_key == "test-api-key"
        assert component.azure_deployment == "gpt-35-turbo"
        assert component.streaming_callback is default_streaming_callback
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_to_dict_default(self):
        component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.azure.AzureOpenAIGenerator",
            "init_parameters": {
                "azure_deployment": "gpt-35-turbo",
                "api_version": "2023-05-15",
                "streaming_callback": None,
                "azure_endpoint": "some-non-existing-endpoint",
                "organization": None,
                "system_prompt": None,
                "generation_kwargs": {},
            },
        }

    def test_to_dict_with_parameters(self):
        component = AzureOpenAIGenerator(
            api_key="test-api-key",
            azure_endpoint="some-non-existing-endpoint",
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )

        data = component.to_dict()
        assert data == {
            "type": "haystack.components.generators.azure.AzureOpenAIGenerator",
            "init_parameters": {
                "azure_deployment": "gpt-35-turbo",
                "api_version": "2023-05-15",
                "streaming_callback": None,
                "azure_endpoint": "some-non-existing-endpoint",
                "organization": None,
                "system_prompt": None,
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
            },
        }

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("AZURE_OPENAI_API_KEY", None) or not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
        reason=(
            "Please export env variables called AZURE_OPENAI_API_KEY containing "
            "the Azure OpenAI key and AZURE_OPENAI_ENDPOINT containing "
            "the Azure OpenAI endpoint URL to run this test."
        ),
    )
    def test_live_run(self):
        component = AzureOpenAIGenerator(organization="HaystackCI")
        results = component.run("What's the capital of France?")
        assert len(results["replies"]) == 1
        assert len(results["meta"]) == 1
        response: str = results["replies"][0]
        assert "Paris" in response

        metadata = results["meta"][0]
        assert "gpt-35-turbo" in metadata["model"]
        assert metadata["finish_reason"] == "stop"

        assert "usage" in metadata
        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0

    # additional tests intentionally omitted as they are covered by test_openai.py
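One offline path neither test module exercises is the endpoint fallback. A hypothetical test, not part of this commit, for the ValueError raised in `__init__`:

```python
import pytest

from haystack.components.generators import AzureOpenAIGenerator


def test_missing_endpoint_raises(monkeypatch):
    # With neither the azure_endpoint argument nor AZURE_OPENAI_ENDPOINT set,
    # __init__ fails fast before any client is constructed.
    monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False)
    with pytest.raises(ValueError):
        AzureOpenAIGenerator(api_key="test-api-key")
```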