Mirror of https://github.com/deepset-ai/haystack.git
feat: Add Azure embedders support (#6676)

* Add Azure embedders

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>

Parent: b7159ad7c2
Commit: 552f0e394b
haystack/components/embedders/__init__.py

@@ -4,6 +4,8 @@ from haystack.components.embedders.sentence_transformers_text_embedder import Se
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder

__all__ = [
    "HuggingFaceTEITextEmbedder",
@@ -12,4 +14,6 @@ __all__ = [
    "SentenceTransformersDocumentEmbedder",
    "OpenAITextEmbedder",
    "OpenAIDocumentEmbedder",
    "AzureOpenAITextEmbedder",
    "AzureOpenAIDocumentEmbedder",
]
178  haystack/components/embedders/azure_document_embedder.py  Normal file
@@ -0,0 +1,178 @@
import os
from typing import List, Optional, Dict, Any, Tuple

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from tqdm import tqdm

from haystack import component, Document, default_to_dict


@component
class AzureOpenAIDocumentEmbedder:
    """
    A component for computing Document embeddings using Azure OpenAI models.
    The embedding of each Document is stored in the `embedding` field of the Document.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.embedders import AzureOpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = AzureOpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create an AzureOpenAIDocumentEmbedder component.

        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :param batch_size: Number of Documents to encode at once.
        :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
            to keep the logs clean.
        :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
        """
        # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator

        self._client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.azure_deployment}

    def to_dict(self) -> Dict[str, Any]:
        """
        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
        to the constructor.
        """
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            ).replace("\n", " ")

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.
        """

        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
        for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
            batch = texts_to_embed[i : i + batch_size]
            response = self._client.embeddings.create(model=self.azure_deployment, input=batch)

            # Append embeddings to the list
            all_embeddings.extend(el.embedding for el in response.data)

            # Update the meta information only once if it's empty
            if not meta["model"]:
                meta["model"] = response.model
                meta["usage"] = dict(response.usage)
            else:
                # Update the usage tokens
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: A list of Documents to embed.
        """
        if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)):
            raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.")

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Assign the corresponding embeddings to each document
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
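For reviewers skimming the diff, a minimal usage sketch of the class above, built only from the constructor parameters and `run()` signature it defines; the endpoint, key, and document are placeholders and the printed values are illustrative:

```python
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder

# Placeholder credentials; in practice these usually come from the
# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY environment variables.
embedder = AzureOpenAIDocumentEmbedder(
    azure_endpoint="https://example-resource.azure.openai.com/",
    api_key="my-azure-openai-key",
    azure_deployment="text-embedding-ada-002",
    meta_fields_to_embed=["topic"],
    embedding_separator=" | ",
)

docs = [Document(content="I love cheese", meta={"topic": "Cuisine"})]

# run() embeds "Cuisine | I love cheese" (prefix/suffix are empty by default)
# and writes the vector into each Document's `embedding` field.
result = embedder.run(documents=docs)
print(len(result["documents"][0].embedding))  # e.g. 1536 for text-embedding-ada-002
print(result["meta"])  # {'model': ..., 'usage': {'prompt_tokens': ..., 'total_tokens': ...}}
```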
123  haystack/components/embedders/azure_text_embedder.py  Normal file
@@ -0,0 +1,123 @@
import os
from typing import List, Optional, Dict, Any

from openai.lib.azure import AzureADTokenProvider, AzureOpenAI

from haystack import component, default_to_dict, Document


@component
class AzureOpenAITextEmbedder:
    """
    A component for embedding strings using Azure OpenAI models.

    Usage example:
    ```python
    from haystack.components.embedders import AzureOpenAITextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = AzureOpenAITextEmbedder()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    # 'meta': {'model': 'text-embedding-ada-002-v2',
    #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
    ```
    """

    def __init__(
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
    ):
        """
        Create an AzureOpenAITextEmbedder component.

        :param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
        :param api_version: The version of the API to use. Defaults to 2023-05-15
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
            on every request.
        :param organization: The Organization ID, defaults to `None`. See
            [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        """
        # Why is this here?
        # The AzureOpenAI client requires either base_url or azure_endpoint to be set at init time.
        # Resolving the env var here accommodates the case where AZURE_OPENAI_ENDPOINT is set
        # instead of being passed as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix

        self._client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key,
            azure_ad_token=azure_ad_token,
            azure_ad_token_provider=azure_ad_token_provider,
            organization=organization,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.azure_deployment}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        :return: The serialized component as a dictionary.
        """
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
        )

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    def run(self, text: str):
        """Embed a string using AzureOpenAITextEmbedder."""
        if not isinstance(text, str):
            # Check if input is a list and all elements are instances of Document
            if isinstance(text, list) and all(isinstance(elem, Document) for elem in text):
                error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents."
            else:
                error_message = "Input must be a string."
            raise TypeError(error_message)

        # Preprocess the text by adding prefixes/suffixes
        # finally, replace newlines as recommended by OpenAI docs
        processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")

        response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)

        return {
            "embedding": response.data[0].embedding,
            "meta": {"model": response.model, "usage": dict(response.usage)},
        }
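A similar sketch for the text embedder above, also showing that `to_dict()` serializes configuration but not the `api_key`; the endpoint and key are placeholders:

```python
from haystack.components.embedders import AzureOpenAITextEmbedder

embedder = AzureOpenAITextEmbedder(
    azure_endpoint="https://example-resource.azure.openai.com/",
    api_key="my-azure-openai-key",
    prefix="query: ",  # prepended to the text before embedding
)

result = embedder.run(text="The food was delicious")
print(result["embedding"][:3])  # first few floats of the embedding vector
print(result["meta"])           # {'model': ..., 'usage': {...}}

# Serialization keeps endpoint, deployment, api_version, and prefix/suffix,
# but not the api_key, so secrets do not leak into pipeline YAML.
print(embedder.to_dict()["init_parameters"].keys())
```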
@@ -0,0 +1,5 @@
---
features:
  - |
    Adds AzureOpenAIDocumentEmbedder and AzureOpenAITextEmbedder as new embedders. These embedders are very similar to
    their OpenAI counterparts, but they use the Azure API instead of the OpenAI API.
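As the release note says, usage mirrors the OpenAI embedders; one point worth highlighting is configuration purely via environment variables, which is also how the integration tests below authenticate. A minimal sketch with placeholder values (normally you would export these in the shell rather than set them in code):

```python
import os

# Placeholder values: the AzureOpenAI client reads AZURE_OPENAI_API_KEY itself,
# and the embedders fall back to AZURE_OPENAI_ENDPOINT when no endpoint is passed.
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://example-resource.azure.openai.com/"
os.environ["AZURE_OPENAI_API_KEY"] = "my-azure-openai-key"

from haystack.components.embedders import (
    AzureOpenAIDocumentEmbedder,
    AzureOpenAITextEmbedder,
)

# No constructor arguments are needed once the env vars are set.
doc_embedder = AzureOpenAIDocumentEmbedder()
text_embedder = AzureOpenAITextEmbedder()
```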
76  test/components/embedders/test_azure_document_embedder.py  Normal file
@@ -0,0 +1,76 @@
import os

import pytest

from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder


class TestAzureOpenAIDocumentEmbedder:
    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
        assert embedder.azure_deployment == "text-embedding-ada-002"
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    def test_to_dict(self):
        component = AzureOpenAIDocumentEmbedder(
            api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
            "init_parameters": {
                "api_version": "2023-05-15",
                "azure_deployment": "text-embedding-ada-002",
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "organization": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "meta_fields_to_embed": [],
                "embedding_separator": "\n",
            },
        }

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
        reason=(
            "Please export env variables called AZURE_OPENAI_API_KEY containing "
            "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
            "the Azure OpenAI endpoint URL to run this test."
        ),
    )
    def test_run(self):
        docs = [
            Document(content="I love cheese", meta={"topic": "Cuisine"}),
            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
        ]
        # the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
        embedder = AzureOpenAIDocumentEmbedder(
            azure_deployment="text-embedding-ada-002",
            meta_fields_to_embed=["topic"],
            embedding_separator=" | ",
            organization="HaystackCI",
        )

        result = embedder.run(documents=docs)
        documents_with_embeddings = result["documents"]
        metadata = result["meta"]

        assert isinstance(documents_with_embeddings, list)
        assert len(documents_with_embeddings) == len(docs)
        for doc in documents_with_embeddings:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert len(doc.embedding) == 1536
            assert all(isinstance(x, float) for x in doc.embedding)
        assert metadata == {"model": "ada", "usage": {"prompt_tokens": 15, "total_tokens": 15}}
54  test/components/embedders/test_azure_text_embedder.py  Normal file
@@ -0,0 +1,54 @@
import os

import pytest

from haystack.components.embedders import AzureOpenAITextEmbedder


class TestAzureOpenAITextEmbedder:
    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
        embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")

        assert embedder._client.api_key == "fake-api-key"
        assert embedder.azure_deployment == "text-embedding-ada-002"
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""

    def test_to_dict(self):
        component = AzureOpenAITextEmbedder(
            api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
            "init_parameters": {
                "azure_deployment": "text-embedding-ada-002",
                "organization": None,
                "azure_endpoint": "https://example-resource.azure.openai.com/",
                "api_version": "2023-05-15",
                "prefix": "",
                "suffix": "",
            },
        }

    @pytest.mark.integration
    @pytest.mark.skipif(
        not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
        reason=(
            "Please export env variables called AZURE_OPENAI_API_KEY containing "
            "the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
            "the Azure OpenAI endpoint URL to run this test."
        ),
    )
    def test_run(self):
        # the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
        embedder = AzureOpenAITextEmbedder(
            azure_deployment="text-embedding-ada-002", prefix="prefix ", suffix=" suffix", organization="HaystackCI"
        )
        result = embedder.run(text="The food was delicious")

        assert len(result["embedding"]) == 1536
        assert all(isinstance(x, float) for x in result["embedding"])
        assert result["meta"] == {"model": "ada", "usage": {"prompt_tokens": 6, "total_tokens": 6}}