diff --git a/haystack/components/embedders/openai_document_embedder.py b/haystack/components/embedders/openai_document_embedder.py
index b45b3521d..92629e1ba 100644
--- a/haystack/components/embedders/openai_document_embedder.py
+++ b/haystack/components/embedders/openai_document_embedder.py
@@ -6,8 +6,9 @@ import os
 from typing import Any, Dict, List, Optional, Tuple
 
 from more_itertools import batched
-from openai import APIError, OpenAI
+from openai import APIError, AsyncOpenAI, OpenAI
 from tqdm import tqdm
+from tqdm.asyncio import tqdm as async_tqdm
 
 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.utils import Secret, deserialize_secrets_inplace
@@ -112,13 +113,16 @@ class OpenAIDocumentEmbedder:
         if max_retries is None:
             max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
 
-        self.client = OpenAI(
-            api_key=api_key.resolve_value(),
-            organization=organization,
-            base_url=api_base_url,
-            timeout=timeout,
-            max_retries=max_retries,
-        )
+        client_args: Dict[str, Any] = {
+            "api_key": api_key.resolve_value(),
+            "organization": organization,
+            "base_url": api_base_url,
+            "timeout": timeout,
+            "max_retries": max_retries,
+        }
+
+        self.client = OpenAI(**client_args)
+        self.async_client = AsyncOpenAI(**client_args)
 
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -216,6 +220,61 @@ class OpenAIDocumentEmbedder:
 
         return all_embeddings, meta
 
+    async def _embed_batch_async(
+        self, texts_to_embed: Dict[str, str], batch_size: int
+    ) -> Tuple[List[Optional[List[float]]], Dict[str, Any]]:
+        """
+        Embed a list of texts in batches asynchronously.
+
+        :param texts_to_embed:
+            Mapping whose values are the texts to embed; the keys are only used to name the
+            affected entries when a batch fails.
+        :param batch_size:
+            Number of texts sent per request.
+
+        :returns:
+            A tuple of the embeddings and the accumulated response metadata. The embeddings are
+            in input order; every text of a batch that failed with an `APIError` is represented
+            by `None` so that later entries keep their position.
+        """
+
+        all_embeddings: List[Optional[List[float]]] = []
+        meta: Dict[str, Any] = {}
+
+        batches = list(batched(texts_to_embed.items(), batch_size))
+        if self.progress_bar:
+            batches = async_tqdm(batches, desc="Calculating embeddings")
+
+        for batch in batches:
+            args: Dict[str, Any] = {"model": self.model, "input": [b[1] for b in batch]}
+
+            if self.dimensions is not None:
+                args["dimensions"] = self.dimensions
+
+            try:
+                response = await self.async_client.embeddings.create(**args)
+            except APIError as exc:
+                ids = ", ".join(b[0] for b in batch)
+                msg = "Failed embedding of documents {ids} caused by {exc}"
+                logger.exception(msg, ids=ids, exc=exc)
+                # Keep one placeholder per skipped text: the caller pairs embeddings with
+                # documents by position, so dropping the batch would shift every later result.
+                all_embeddings.extend([None] * len(batch))
+                continue
+
+            embeddings = [el.embedding for el in response.data]
+            all_embeddings.extend(embeddings)
+
+            if "model" not in meta:
+                meta["model"] = response.model
+            if "usage" not in meta:
+                meta["usage"] = dict(response.usage)
+            else:
+                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
+                meta["usage"]["total_tokens"] += response.usage.total_tokens
+
+        return all_embeddings, meta
+
     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document]):
         """
@@ -243,3 +302,37 @@ class OpenAIDocumentEmbedder:
             doc.embedding = emb
 
         return {"documents": documents, "meta": meta}
+
+    @component.output_types(documents=List[Document], meta=Dict[str, Any])
+    async def run_async(self, documents: List[Document]):
+        """
+        Embeds a list of documents asynchronously.
+
+        :param documents:
+            A list of documents to embed.
+
+        :returns:
+            A dictionary with the following keys:
+            - `documents`: A list of documents with embeddings. Documents whose batch failed keep
+              their previous `embedding` value.
+            - `meta`: Information about the usage of the model.
+
+        :raises TypeError:
+            If `documents` is not a list, or its first element is not a `Document`.
+        """
+        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+            raise TypeError(
+                "OpenAIDocumentEmbedder expects a list of Documents as input. "
+                "In case you want to embed a string, please use the OpenAITextEmbedder."
+            )
+
+        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
+
+        embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
+
+        for doc, emb in zip(documents, embeddings):
+            # `None` marks a text whose batch failed; leave that document untouched.
+            if emb is not None:
+                doc.embedding = emb
+
+        return {"documents": documents, "meta": meta}
diff --git a/releasenotes/notes/openai-document-embedder-async-support-b46f1e84043da366.yaml b/releasenotes/notes/openai-document-embedder-async-support-b46f1e84043da366.yaml
new file mode 100644
index 000000000..97fe98d9e
--- /dev/null
+++ b/releasenotes/notes/openai-document-embedder-async-support-b46f1e84043da366.yaml
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Added async support to the `OpenAIDocumentEmbedder` component.
diff --git a/test/components/embedders/test_openai_document_embedder.py b/test/components/embedders/test_openai_document_embedder.py
index 7d43bcfa8..1e8501873 100644
--- a/test/components/embedders/test_openai_document_embedder.py
+++ b/test/components/embedders/test_openai_document_embedder.py
@@ -256,3 +256,33 @@ class TestOpenAIDocumentEmbedder:
         )
 
         assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"
+
+    @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
+    @pytest.mark.integration
+    @pytest.mark.asyncio
+    async def test_run_async(self):
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        model = "text-embedding-ada-002"
+
+        embedder = OpenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")
+
+        result = await embedder.run_async(documents=docs)
+        documents_with_embeddings = result["documents"]
+
+        assert isinstance(documents_with_embeddings, list)
+        assert len(documents_with_embeddings) == len(docs)
+        for doc in documents_with_embeddings:
+            assert isinstance(doc, Document)
+            assert isinstance(doc.embedding, list)
+            assert len(doc.embedding) == 1536
+            assert all(isinstance(x, float) for x in doc.embedding)
+
+        assert "text" in result["meta"]["model"] and "ada" in result["meta"]["model"], (
+            "The model name does not contain 'text' and 'ada'"
+        )
+
+        assert result["meta"]["usage"] == {"prompt_tokens": 15, "total_tokens": 15}, "Usage information does not match"