feat(embedders): Add async support for OpenAI document embedder (#9140)

* feat(embedders): Add async support for OpenAI document embedder

* add release notes

* resolve review comments

* Update releasenotes/notes/openai-document-embedder-async-support-b46f1e84043da366.yaml

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>

* Update openai-document-embedder-async-support-b46f1e84043da366.yaml

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
Mohammed Abdul Razak Wahab 2025-04-04 17:25:59 +05:30 committed by GitHub
parent c6df8d2c7a
commit a2f73d134d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 115 additions and 8 deletions

View File

@ -6,8 +6,9 @@ import os
from typing import Any, Dict, List, Optional, Tuple
from more_itertools import batched
from openai import APIError, OpenAI
from openai import APIError, AsyncOpenAI, OpenAI
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace
@ -112,13 +113,16 @@ class OpenAIDocumentEmbedder:
if max_retries is None:
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
self.client = OpenAI(
api_key=api_key.resolve_value(),
organization=organization,
base_url=api_base_url,
timeout=timeout,
max_retries=max_retries,
)
client_args: Dict[str, Any] = {
"api_key": api_key.resolve_value(),
"organization": organization,
"base_url": api_base_url,
"timeout": timeout,
"max_retries": max_retries,
}
self.client = OpenAI(**client_args)
self.async_client = AsyncOpenAI(**client_args)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -216,6 +220,47 @@ class OpenAIDocumentEmbedder:
return all_embeddings, meta
async def _embed_batch_async(
    self, texts_to_embed: Dict[str, str], batch_size: int
) -> Tuple[List[List[float]], Dict[str, Any]]:
    """
    Embed the given id->text mapping in batches using the async OpenAI client.

    :param texts_to_embed: Mapping from document id to the text to embed.
    :param batch_size: Number of texts sent per API request.
    :returns:
        A tuple of (embeddings in input order, metadata dict with the model name
        and accumulated token usage).
    """
    embeddings_out: List[List[float]] = []
    info: Dict[str, Any] = {}

    chunks = list(batched(texts_to_embed.items(), batch_size))
    if self.progress_bar:
        # tqdm.asyncio's tqdm also supports plain synchronous iteration.
        chunks = async_tqdm(chunks, desc="Calculating embeddings")

    for chunk in chunks:
        request: Dict[str, Any] = {"model": self.model, "input": [text for _, text in chunk]}
        if self.dimensions is not None:
            request["dimensions"] = self.dimensions

        try:
            response = await self.async_client.embeddings.create(**request)
        except APIError as exc:
            # Log and skip the failed batch rather than aborting the whole run.
            ids = ", ".join(doc_id for doc_id, _ in chunk)
            msg = "Failed embedding of documents {ids} caused by {exc}"
            logger.exception(msg, ids=ids, exc=exc)
            continue

        embeddings_out.extend(item.embedding for item in response.data)

        if "model" not in info:
            info["model"] = response.model
        if "usage" not in info:
            info["usage"] = dict(response.usage)
        else:
            # Accumulate token counts across batches after the first response.
            info["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            info["usage"]["total_tokens"] += response.usage.total_tokens

    return embeddings_out, info
@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
"""
@ -243,3 +288,31 @@ class OpenAIDocumentEmbedder:
doc.embedding = emb
return {"documents": documents, "meta": meta}
@component.output_types(documents=List[Document], meta=Dict[str, Any])
async def run_async(self, documents: List[Document]):
    """
    Embeds a list of documents asynchronously.

    :param documents:
        A list of documents to embed.

    :returns:
        A dictionary with the following keys:
        - `documents`: A list of documents with embeddings.
        - `meta`: Information about the usage of the model.
    """
    # Guard against a plain string (or any other non-Document input) being passed in.
    is_document_list = isinstance(documents, list) and (not documents or isinstance(documents[0], Document))
    if not is_document_list:
        raise TypeError(
            "OpenAIDocumentEmbedder expects a list of Documents as input. "
            "In case you want to embed a string, please use the OpenAITextEmbedder."
        )

    texts_to_embed = self._prepare_texts_to_embed(documents=documents)
    embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

    # Attach each embedding to its document in order.
    for document, embedding in zip(documents, embeddings):
        document.embedding = embedding

    return {"documents": documents, "meta": meta}

View File

@ -0,0 +1,4 @@
---
features:
- |
Added async support to the `OpenAIDocumentEmbedder` component.

View File

@ -256,3 +256,33 @@ class TestOpenAIDocumentEmbedder:
)
assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
@pytest.mark.integration
@pytest.mark.asyncio
async def test_run_async(self):
    """Integration test: run_async embeds real documents via the OpenAI API."""
    model = "text-embedding-ada-002"
    docs = [
        Document(content="I love cheese", meta={"topic": "Cuisine"}),
        Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
    ]
    embedder = OpenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")

    result = await embedder.run_async(documents=docs)
    embedded_docs = result["documents"]

    assert isinstance(embedded_docs, list)
    assert len(embedded_docs) == len(docs)
    for doc in embedded_docs:
        assert isinstance(doc, Document)
        assert isinstance(doc.embedding, list)
        # text-embedding-ada-002 returns 1536-dimensional float vectors.
        assert len(doc.embedding) == 1536
        assert all(isinstance(value, float) for value in doc.embedding)

    assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
        "The model name does not contain 'text' and 'ada'"
    )
    assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"