mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 21:20:52 +00:00
feat(embedders): Add async support for OpenAI document embedder (#9140)
* feat(embedders): Add async support for OpenAI document embedder
* add release notes
* resolve review comments
* Update releasenotes/notes/openai-document-embedder-async-support-b46f1e84043da366.yaml
* Update openai-document-embedder-async-support-b46f1e84043da366.yaml

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Co-authored-by: David S. Batista <dsbatista@gmail.com>
This commit is contained in:
parent
c6df8d2c7a
commit
a2f73d134d
@ -6,8 +6,9 @@ import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from more_itertools import batched
|
||||
from openai import APIError, OpenAI
|
||||
from openai import APIError, AsyncOpenAI, OpenAI
|
||||
from tqdm import tqdm
|
||||
from tqdm.asyncio import tqdm as async_tqdm
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
@ -112,13 +113,16 @@ class OpenAIDocumentEmbedder:
|
||||
if max_retries is None:
|
||||
max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key.resolve_value(),
|
||||
organization=organization,
|
||||
base_url=api_base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
client_args: Dict[str, Any] = {
|
||||
"api_key": api_key.resolve_value(),
|
||||
"organization": organization,
|
||||
"base_url": api_base_url,
|
||||
"timeout": timeout,
|
||||
"max_retries": max_retries,
|
||||
}
|
||||
|
||||
self.client = OpenAI(**client_args)
|
||||
self.async_client = AsyncOpenAI(**client_args)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -216,6 +220,47 @@ class OpenAIDocumentEmbedder:
|
||||
|
||||
return all_embeddings, meta
|
||||
|
||||
async def _embed_batch_async(
    self, texts_to_embed: Dict[str, str], batch_size: int
) -> Tuple[List[List[float]], Dict[str, Any]]:
    """
    Embed a list of texts in batches asynchronously.

    Each batch is sent to the OpenAI embeddings endpoint via the async client;
    failed batches are logged and skipped (best-effort), and token usage is
    accumulated across batches in the returned metadata.
    """
    embeddings_out: List[List[float]] = []
    metadata: Dict[str, Any] = {}

    chunks = list(batched(texts_to_embed.items(), batch_size))
    # async_tqdm wraps the iterable only when a progress bar was requested
    iterator = async_tqdm(chunks, desc="Calculating embeddings") if self.progress_bar else chunks

    for chunk in iterator:
        request: Dict[str, Any] = {"model": self.model, "input": [item[1] for item in chunk]}
        if self.dimensions is not None:
            request["dimensions"] = self.dimensions

        try:
            response = await self.async_client.embeddings.create(**request)
        except APIError as exc:
            # Best-effort behavior: report the failed document ids and continue
            # with the remaining batches instead of aborting the whole run.
            failed_ids = ", ".join(item[0] for item in chunk)
            msg = "Failed embedding of documents {ids} caused by {exc}"
            logger.exception(msg, ids=failed_ids, exc=exc)
            continue

        embeddings_out.extend(entry.embedding for entry in response.data)

        if "model" not in metadata:
            metadata["model"] = response.model
        if "usage" not in metadata:
            metadata["usage"] = dict(response.usage)
        else:
            # Aggregate token counts across all successful batches.
            metadata["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            metadata["usage"]["total_tokens"] += response.usage.total_tokens

    return embeddings_out, metadata
|
||||
|
||||
@component.output_types(documents=List[Document], meta=Dict[str, Any])
|
||||
def run(self, documents: List[Document]):
|
||||
"""
|
||||
@ -243,3 +288,31 @@ class OpenAIDocumentEmbedder:
|
||||
doc.embedding = emb
|
||||
|
||||
return {"documents": documents, "meta": meta}
|
||||
|
||||
@component.output_types(documents=List[Document], meta=Dict[str, Any])
async def run_async(self, documents: List[Document]):
    """
    Embeds a list of documents asynchronously.

    :param documents:
        A list of documents to embed.

    :returns:
        A dictionary with the following keys:
        - `documents`: A list of documents with embeddings.
        - `meta`: Information about the usage of the model.
    """
    # Input must be a list whose first element (if any) is a Document.
    is_doc_list = isinstance(documents, list) and (not documents or isinstance(documents[0], Document))
    if not is_doc_list:
        raise TypeError(
            "OpenAIDocumentEmbedder expects a list of Documents as input. "
            "In case you want to embed a string, please use the OpenAITextEmbedder."
        )

    prepared = self._prepare_texts_to_embed(documents=documents)

    embeddings, meta = await self._embed_batch_async(texts_to_embed=prepared, batch_size=self.batch_size)

    # Attach each embedding to its originating document in order.
    for document, embedding in zip(documents, embeddings):
        document.embedding = embedding

    return {"documents": documents, "meta": meta}
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Added async support to the `OpenAIDocumentEmbedder` component.
|
||||
@ -256,3 +256,33 @@ class TestOpenAIDocumentEmbedder:
|
||||
)
|
||||
|
||||
assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"
|
||||
|
||||
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
@pytest.mark.integration
@pytest.mark.asyncio
async def test_run_async(self):
    # Integration test: embeds two documents through the real OpenAI API
    # and checks shape, types, model name, and usage accounting.
    model = "text-embedding-ada-002"
    docs = [
        Document(content="I love cheese", meta={"topic": "Cuisine"}),
        Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
    ]

    embedder = OpenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")

    result = await embedder.run_async(documents=docs)
    embedded_docs = result["documents"]

    assert isinstance(embedded_docs, list)
    assert len(embedded_docs) == len(docs)
    for document in embedded_docs:
        assert isinstance(document, Document)
        assert isinstance(document.embedding, list)
        # ada-002 produces 1536-dimensional float vectors
        assert len(document.embedding) == 1536
        assert all(isinstance(value, float) for value in document.embedding)

    model_name = result["meta"]["model"]
    assert "text" in model_name and "ada" in model_name, (
        "The model name does not contain 'text' and 'ada'"
    )

    assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user