feat: Add Azure embedders support (#6676)

* Add Azure embedders
---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2024-01-05 15:49:25 +01:00 committed by GitHub
parent b7159ad7c2
commit 552f0e394b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 440 additions and 0 deletions

View File

@ -4,6 +4,8 @@ from haystack.components.embedders.sentence_transformers_text_embedder import Se
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder
__all__ = [
"HuggingFaceTEITextEmbedder",
@ -12,4 +14,6 @@ __all__ = [
"SentenceTransformersDocumentEmbedder",
"OpenAITextEmbedder",
"OpenAIDocumentEmbedder",
"AzureOpenAITextEmbedder",
"AzureOpenAIDocumentEmbedder",
]

View File

@ -0,0 +1,178 @@
import os
from typing import List, Optional, Dict, Any, Tuple
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from tqdm import tqdm
from haystack import component, Document, default_to_dict
@component
class AzureOpenAIDocumentEmbedder:
"""
A component for computing Document embeddings using OpenAI models.
The embedding of each Document is stored in the `embedding` field of the Document.
Usage example:
```python
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
doc = Document(content="I love pizza!")
document_embedder = AzureOpenAIDocumentEmbedder()
result = document_embedder.run([doc])
print(result['documents'][0].embedding)
# [0.017020374536514282, -0.023255806416273117, ...]
```
"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Create an AzureOpenAITextEmbedder component.
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of Documents to encode at once.
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
# if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]
text_to_embed = (
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
).replace("\n", " ")
texts_to_embed.append(text_to_embed)
return texts_to_embed
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""
all_embeddings: List[List[float]] = []
meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}
for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Texts"):
batch = texts_to_embed[i : i + batch_size]
response = self._client.embeddings.create(model=self.azure_deployment, input=batch)
# Append embeddings to the list
all_embeddings.extend(el.embedding for el in response.data)
# Update the meta information only once if it's empty
if not meta["model"]:
meta["model"] = response.model
meta["usage"] = dict(response.usage)
else:
# Update the usage tokens
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
meta["usage"]["total_tokens"] += response.usage.total_tokens
return all_embeddings, meta
@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
"""
Embed a list of Documents. The embedding of each Document is stored in the `embedding` field of the Document.
:param documents: A list of Documents to embed.
"""
if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)):
raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.")
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
# Assign the corresponding embeddings to each document
for doc, emb in zip(documents, embeddings):
doc.embedding = emb
return {"documents": documents, "meta": meta}

View File

@ -0,0 +1,123 @@
import os
from typing import List, Optional, Dict, Any
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from haystack import component, default_to_dict, Document
@component
class AzureOpenAITextEmbedder:
"""
A component for embedding strings using OpenAI models.
Usage example:
```python
from haystack.components.embedders import AzureOpenAITextEmbedder
text_to_embed = "I love pizza!"
text_embedder = AzureOpenAITextEmbedder()
print(text_embedder.run(text_to_embed))
# {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
# 'meta': {'model': 'text-embedding-ada-002-v2',
# 'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
```
"""
def __init__(
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
):
"""
Create an AzureOpenAITextEmbedder component.
:param azure_endpoint: The endpoint of the deployed model, e.g. `https://example-resource.azure.openai.com/`
:param api_version: The version of the API to use. Defaults to 2023-05-15
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
# Why is this here?
# AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
# None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
# of passing it as a parameter.
azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
self.organization = organization
self.prefix = prefix
self.suffix = suffix
self._client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization,
)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.azure_deployment}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
azure_deployment=self.azure_deployment,
organization=self.organization,
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
)
@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string using AzureOpenAITextEmbedder."""
if not isinstance(text, str):
# Check if input is a list and all elements are instances of Document
if isinstance(text, list) and all(isinstance(elem, Document) for elem in text):
error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents."
else:
error_message = "Input must be a string."
raise TypeError(error_message)
# Preprocess the text by adding prefixes/suffixes
# finally, replace newlines as recommended by OpenAI docs
processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")
response = self._client.embeddings.create(model=self.azure_deployment, input=processed_text)
return {
"embedding": response.data[0].embedding,
"meta": {"model": response.model, "usage": dict(response.usage)},
}

View File

@ -0,0 +1,5 @@
---
features:
- |
Adds AzureOpenAIDocumentEmbedder and AzureOpenAITextEmbedder as new embedders. These embedders are very similar to
their OpenAI counterparts, but they use the Azure API instead of the OpenAI API.

View File

@ -0,0 +1,76 @@
import os
import pytest
from haystack import Document
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
class TestAzureOpenAIDocumentEmbedder:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
embedder = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
def test_to_dict(self):
component = AzureOpenAIDocumentEmbedder(
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
)
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
"init_parameters": {
"api_version": "2023-05-15",
"azure_deployment": "text-embedding-ada-002",
"azure_endpoint": "https://example-resource.azure.openai.com/",
"organization": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
reason=(
"Please export env variables called AZURE_OPENAI_API_KEY containing "
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
"the Azure OpenAI endpoint URL to run this test."
),
)
def test_run(self):
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]
# the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
embedder = AzureOpenAIDocumentEmbedder(
azure_deployment="text-embedding-ada-002",
meta_fields_to_embed=["topic"],
embedding_separator=" | ",
organization="HaystackCI",
)
result = embedder.run(documents=docs)
documents_with_embeddings = result["documents"]
metadata = result["meta"]
assert isinstance(documents_with_embeddings, list)
assert len(documents_with_embeddings) == len(docs)
for doc in documents_with_embeddings:
assert isinstance(doc, Document)
assert isinstance(doc.embedding, list)
assert len(doc.embedding) == 1536
assert all(isinstance(x, float) for x in doc.embedding)
assert metadata == {"model": "ada", "usage": {"prompt_tokens": 15, "total_tokens": 15}}

View File

@ -0,0 +1,54 @@
import os
import pytest
from haystack.components.embedders import AzureOpenAITextEmbedder
class TestAzureOpenAITextEmbedder:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
embedder = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
assert embedder._client.api_key == "fake-api-key"
assert embedder.azure_deployment == "text-embedding-ada-002"
assert embedder.organization is None
assert embedder.prefix == ""
assert embedder.suffix == ""
def test_to_dict(self):
component = AzureOpenAITextEmbedder(
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
)
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
"init_parameters": {
"azure_deployment": "text-embedding-ada-002",
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",
"api_version": "2023-05-15",
"prefix": "",
"suffix": "",
},
}
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
reason=(
"Please export env variables called AZURE_OPENAI_API_KEY containing "
"the Azure OpenAI key, AZURE_OPENAI_ENDPOINT containing "
"the Azure OpenAI endpoint URL to run this test."
),
)
def test_run(self):
# the default model is text-embedding-ada-002 even if we don't specify it, but let's be explicit
embedder = AzureOpenAITextEmbedder(
azure_deployment="text-embedding-ada-002", prefix="prefix ", suffix=" suffix", organization="HaystackCI"
)
result = embedder.run(text="The food was delicious")
assert len(result["embedding"]) == 1536
assert all(isinstance(x, float) for x in result["embedding"])
assert result["meta"] == {"model": "ada", "usage": {"prompt_tokens": 6, "total_tokens": 6}}