feat: Add raise_on_failure boolean parameter to OpenAIDocumentEmbedder and AzureOpenAIDocumentEmbedder (#9474)

* Add raise_on_failure to OpenAIDocumentEmbedder

* Add reno

* Add parameter to Azure Doc embedder as well

* Fix bug

* Update reno

* PR comments

* update reno
This commit is contained in:
Sebastian Husch Lee 2025-06-03 12:22:34 +02:00 committed by GitHub
parent 5fcd7c4732
commit ce0917e586
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 112 additions and 17 deletions

View File

@ -59,6 +59,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
default_headers: Optional[Dict[str, str]] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
raise_on_failure: bool = False,
):
"""
Creates an AzureOpenAIDocumentEmbedder component.
@ -109,6 +110,9 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
:param raise_on_failure:
Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
and continue processing the remaining documents. If `True`, it will raise an exception on failure.
"""
# We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
# with the API.
@ -140,6 +144,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
self.default_headers = default_headers or {}
self.azure_ad_token_provider = azure_ad_token_provider
self.http_client_kwargs = http_client_kwargs
self.raise_on_failure = raise_on_failure
client_args: Dict[str, Any] = {
"api_version": api_version,
@ -191,6 +196,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
default_headers=self.default_headers,
azure_ad_token_provider=azure_ad_token_provider_name,
http_client_kwargs=self.http_client_kwargs,
raise_on_failure=self.raise_on_failure,
)
@classmethod

View File

@ -39,7 +39,7 @@ class OpenAIDocumentEmbedder:
```
"""
def __init__( # pylint: disable=too-many-positional-arguments
def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
@ -55,6 +55,8 @@ class OpenAIDocumentEmbedder:
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
http_client_kwargs: Optional[Dict[str, Any]] = None,
*,
raise_on_failure: bool = False,
):
"""
Creates an OpenAIDocumentEmbedder component.
@ -100,6 +102,9 @@ class OpenAIDocumentEmbedder:
:param http_client_kwargs:
A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
:param raise_on_failure:
Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
and continue processing the remaining documents. If `True`, it will raise an exception on failure.
"""
self.api_key = api_key
self.model = model
@ -115,6 +120,7 @@ class OpenAIDocumentEmbedder:
self.timeout = timeout
self.max_retries = max_retries
self.http_client_kwargs = http_client_kwargs
self.raise_on_failure = raise_on_failure
if timeout is None:
timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
@ -163,6 +169,7 @@ class OpenAIDocumentEmbedder:
timeout=self.timeout,
max_retries=self.max_retries,
http_client_kwargs=self.http_client_kwargs,
raise_on_failure=self.raise_on_failure,
)
@classmethod
@ -194,12 +201,14 @@ class OpenAIDocumentEmbedder:
return texts_to_embed
def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
def _embed_batch(
self, texts_to_embed: Dict[str, str], batch_size: int
) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches.
"""
all_embeddings = []
doc_ids_to_embeddings: Dict[str, List[float]] = {}
meta: Dict[str, Any] = {}
for batch in tqdm(
batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
@ -215,10 +224,12 @@ class OpenAIDocumentEmbedder:
ids = ", ".join(b[0] for b in batch)
msg = "Failed embedding of documents {ids} caused by {exc}"
logger.exception(msg, ids=ids, exc=exc)
if self.raise_on_failure:
raise exc
continue
embeddings = [el.embedding for el in response.data]
all_embeddings.extend(embeddings)
doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))
if "model" not in meta:
meta["model"] = response.model
@ -228,16 +239,16 @@ class OpenAIDocumentEmbedder:
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
meta["usage"]["total_tokens"] += response.usage.total_tokens
return all_embeddings, meta
return doc_ids_to_embeddings, meta
async def _embed_batch_async(
self, texts_to_embed: Dict[str, str], batch_size: int
) -> Tuple[List[List[float]], Dict[str, Any]]:
) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
"""
Embed a list of texts in batches asynchronously.
"""
all_embeddings = []
doc_ids_to_embeddings: Dict[str, List[float]] = {}
meta: Dict[str, Any] = {}
batches = list(batched(texts_to_embed.items(), batch_size))
@ -256,10 +267,12 @@ class OpenAIDocumentEmbedder:
ids = ", ".join(b[0] for b in batch)
msg = "Failed embedding of documents {ids} caused by {exc}"
logger.exception(msg, ids=ids, exc=exc)
if self.raise_on_failure:
raise exc
continue
embeddings = [el.embedding for el in response.data]
all_embeddings.extend(embeddings)
doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))
if "model" not in meta:
meta["model"] = response.model
@ -269,7 +282,7 @@ class OpenAIDocumentEmbedder:
meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
meta["usage"]["total_tokens"] += response.usage.total_tokens
return all_embeddings, meta
return doc_ids_to_embeddings, meta
@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
@ -292,12 +305,13 @@ class OpenAIDocumentEmbedder:
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
for doc, emb in zip(documents, embeddings):
doc.embedding = emb
doc_id_to_document = {doc.id: doc for doc in documents}
for doc_id, emb in doc_ids_to_embeddings.items():
doc_id_to_document[doc_id].embedding = emb
return {"documents": documents, "meta": meta}
return {"documents": list(doc_id_to_document.values()), "meta": meta}
@component.output_types(documents=List[Document], meta=Dict[str, Any])
async def run_async(self, documents: List[Document]):
@ -320,9 +334,12 @@ class OpenAIDocumentEmbedder:
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
doc_ids_to_embeddings, meta = await self._embed_batch_async(
texts_to_embed=texts_to_embed, batch_size=self.batch_size
)
for doc, emb in zip(documents, embeddings):
doc.embedding = emb
doc_id_to_document = {doc.id: doc for doc in documents}
for doc_id, emb in doc_ids_to_embeddings.items():
doc_id_to_document[doc_id].embedding = emb
return {"documents": documents, "meta": meta}
return {"documents": list(doc_id_to_document.values()), "meta": meta}

View File

@ -0,0 +1,7 @@
---
features:
- |
    Added a `raise_on_failure` boolean parameter to `OpenAIDocumentEmbedder` and `AzureOpenAIDocumentEmbedder`. If set to `True`, the component raises an exception when an API request fails. It defaults to `False`, so the previous behavior of logging the exception and continuing remains the default.
fixes:
- |
    Fix bug where, if `raise_on_failure=False` and an error occurred mid-batch, the embeddings computed after the failure would be paired with the wrong documents.

View File

@ -78,6 +78,7 @@ class TestAzureOpenAIDocumentEmbedder:
"default_headers": {},
"azure_ad_token_provider": None,
"http_client_kwargs": None,
"raise_on_failure": False,
},
}
@ -95,6 +96,7 @@ class TestAzureOpenAIDocumentEmbedder:
default_headers={"x-custom-header": "custom-value"},
azure_ad_token_provider=default_azure_ad_token_provider,
http_client_kwargs={"proxy": "http://example.com:3128", "verify": False},
raise_on_failure=True,
)
data = component.to_dict()
assert data == {
@ -118,6 +120,7 @@ class TestAzureOpenAIDocumentEmbedder:
"default_headers": {"x-custom-header": "custom-value"},
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
"http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
"raise_on_failure": True,
},
}
@ -144,6 +147,7 @@ class TestAzureOpenAIDocumentEmbedder:
"default_headers": {},
"azure_ad_token_provider": None,
"http_client_kwargs": None,
"raise_on_failure": False,
},
}
component = AzureOpenAIDocumentEmbedder.from_dict(data)
@ -157,6 +161,7 @@ class TestAzureOpenAIDocumentEmbedder:
assert component.default_headers == {}
assert component.azure_ad_token_provider is None
assert component.http_client_kwargs is None
assert component.raise_on_failure is False
def test_from_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
@ -181,6 +186,7 @@ class TestAzureOpenAIDocumentEmbedder:
"default_headers": {"x-custom-header": "custom-value"},
"azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
"http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
"raise_on_failure": True,
},
}
component = AzureOpenAIDocumentEmbedder.from_dict(data)
@ -194,6 +200,7 @@ class TestAzureOpenAIDocumentEmbedder:
assert component.default_headers == {"x-custom-header": "custom-value"}
assert component.azure_ad_token_provider is not None
assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
assert component.raise_on_failure is True
def test_embed_batch_handles_exceptions_gracefully(self, caplog):
embedder = AzureOpenAIDocumentEmbedder(
@ -215,6 +222,22 @@ class TestAzureOpenAIDocumentEmbedder:
assert len(caplog.records) == 1
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text
def test_embed_batch_raises_exception_on_failure(self):
embedder = AzureOpenAIDocumentEmbedder(
azure_endpoint="https://test.openai.azure.com",
api_key=Secret.from_token("fake-api-key"),
azure_deployment="text-embedding-ada-002",
raise_on_failure=True,
)
fake_texts_to_embed = {"1": "text1", "2": "text2"}
with patch.object(
embedder.client.embeddings,
"create",
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
):
with pytest.raises(APIError, match="Mocked error"):
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
@pytest.mark.integration
@pytest.mark.skipif(
not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),

View File

@ -125,6 +125,7 @@ class TestOpenAIDocumentEmbedder:
"embedding_separator": "\n",
"timeout": None,
"max_retries": None,
"raise_on_failure": False,
},
}
@ -143,6 +144,7 @@ class TestOpenAIDocumentEmbedder:
embedding_separator=" | ",
timeout=10.0,
max_retries=2,
raise_on_failure=True,
)
data = component.to_dict()
assert data == {
@ -162,6 +164,7 @@ class TestOpenAIDocumentEmbedder:
"embedding_separator": " | ",
"timeout": 10.0,
"max_retries": 2,
"raise_on_failure": True,
},
}
@ -237,6 +240,45 @@ class TestOpenAIDocumentEmbedder:
assert len(caplog.records) == 1
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.records[0].msg
def test_run_handles_exceptions_gracefully(self, caplog):
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), batch_size=1)
docs = [
Document(content="I love cheese", meta={"topic": "Cuisine"}),
Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
]
# Create a successful response for the second call
successful_response = Mock()
successful_response.data = [
Mock(embedding=[0.4, 0.5, 0.6]) # Mock embedding for second doc
]
successful_response.model = "text-embedding-ada-002"
successful_response.usage = {"prompt_tokens": 10, "total_tokens": 10}
with patch.object(
embedder.client.embeddings,
"create",
side_effect=[
APIError(message="Mocked error", request=Mock(), body=None), # First call fails
successful_response, # Second call succeeds
],
):
result = embedder.run(documents=docs)
assert len(result["documents"]) == 2
assert result["documents"][0].embedding is None
assert result["documents"][1].embedding == [0.4, 0.5, 0.6]
def test_embed_batch_raises_exception_on_failure(self):
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), raise_on_failure=True)
fake_texts_to_embed = {"1": "text1", "2": "text2"}
with patch.object(
embedder.client.embeddings,
"create",
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
):
with pytest.raises(APIError, match="Mocked error"):
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
@pytest.mark.integration
def test_run(self):