Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-06-26 22:00:13 +00:00

feat: Add raise_on_failure boolean parameter to OpenAIDocumentEmbedder and AzureOpenAIDocumentEmbedder (#9474)

* Add raise_on_failure to OpenAIDocumentEmbedder
* Add reno
* Add parameter to Azure Doc embedder as well
* Fix bug
* Update reno
* PR comments
* update reno

This commit is contained in:
parent 5fcd7c4732
commit ce0917e586
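In practice the new flag is set at construction time. A minimal usage sketch (assumes OPENAI_API_KEY is exported; the document content is made up):

from haystack import Document
from haystack.components.embedders import OpenAIDocumentEmbedder

# With raise_on_failure=True, an APIError from the embeddings endpoint propagates to the
# caller; the default (False) keeps the old behavior of logging the error and skipping the
# failed batch. AzureOpenAIDocumentEmbedder forwards the same flag.
embedder = OpenAIDocumentEmbedder(raise_on_failure=True)
result = embedder.run(documents=[Document(content="I love cheese")])
print(result["meta"])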
@@ -59,6 +59,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
         default_headers: Optional[Dict[str, str]] = None,
         azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
         http_client_kwargs: Optional[Dict[str, Any]] = None,
+        raise_on_failure: bool = False,
     ):
         """
         Creates an AzureOpenAIDocumentEmbedder component.

@@ -109,6 +110,9 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
         :param http_client_kwargs:
             A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
             For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
+        :param raise_on_failure:
+            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
+            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
         """
         # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
         # with the API.

@@ -140,6 +144,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
         self.default_headers = default_headers or {}
         self.azure_ad_token_provider = azure_ad_token_provider
         self.http_client_kwargs = http_client_kwargs
+        self.raise_on_failure = raise_on_failure

         client_args: Dict[str, Any] = {
             "api_version": api_version,

@@ -191,6 +196,7 @@ class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
             default_headers=self.default_headers,
             azure_ad_token_provider=azure_ad_token_provider_name,
             http_client_kwargs=self.http_client_kwargs,
+            raise_on_failure=self.raise_on_failure,
         )

     @classmethod
@@ -39,7 +39,7 @@ class OpenAIDocumentEmbedder:
     ```
     """

-    def __init__(  # pylint: disable=too-many-positional-arguments
+    def __init__(  # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "text-embedding-ada-002",

@@ -55,6 +55,8 @@ class OpenAIDocumentEmbedder:
         timeout: Optional[float] = None,
         max_retries: Optional[int] = None,
         http_client_kwargs: Optional[Dict[str, Any]] = None,
+        *,
+        raise_on_failure: bool = False,
     ):
         """
         Creates an OpenAIDocumentEmbedder component.

@@ -100,6 +102,9 @@ class OpenAIDocumentEmbedder:
         :param http_client_kwargs:
             A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
             For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
+        :param raise_on_failure:
+            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
+            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
         """
         self.api_key = api_key
         self.model = model

@@ -115,6 +120,7 @@ class OpenAIDocumentEmbedder:
         self.timeout = timeout
         self.max_retries = max_retries
         self.http_client_kwargs = http_client_kwargs
+        self.raise_on_failure = raise_on_failure

         if timeout is None:
             timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))

@@ -163,6 +169,7 @@ class OpenAIDocumentEmbedder:
             timeout=self.timeout,
             max_retries=self.max_retries,
             http_client_kwargs=self.http_client_kwargs,
+            raise_on_failure=self.raise_on_failure,
         )

     @classmethod

@@ -194,12 +201,14 @@ class OpenAIDocumentEmbedder:

         return texts_to_embed

-    def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
+    def _embed_batch(
+        self, texts_to_embed: Dict[str, str], batch_size: int
+    ) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
         """
         Embed a list of texts in batches.
         """

-        all_embeddings = []
+        doc_ids_to_embeddings: Dict[str, List[float]] = {}
         meta: Dict[str, Any] = {}
         for batch in tqdm(
             batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"

@@ -215,10 +224,12 @@ class OpenAIDocumentEmbedder:
                 ids = ", ".join(b[0] for b in batch)
                 msg = "Failed embedding of documents {ids} caused by {exc}"
                 logger.exception(msg, ids=ids, exc=exc)
+                if self.raise_on_failure:
+                    raise exc
                 continue

             embeddings = [el.embedding for el in response.data]
-            all_embeddings.extend(embeddings)
+            doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))

             if "model" not in meta:
                 meta["model"] = response.model

@@ -228,16 +239,16 @@ class OpenAIDocumentEmbedder:
                 meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                 meta["usage"]["total_tokens"] += response.usage.total_tokens

-        return all_embeddings, meta
+        return doc_ids_to_embeddings, meta

     async def _embed_batch_async(
         self, texts_to_embed: Dict[str, str], batch_size: int
-    ) -> Tuple[List[List[float]], Dict[str, Any]]:
+    ) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
         """
         Embed a list of texts in batches asynchronously.
         """

-        all_embeddings = []
+        doc_ids_to_embeddings: Dict[str, List[float]] = {}
         meta: Dict[str, Any] = {}

         batches = list(batched(texts_to_embed.items(), batch_size))

@@ -256,10 +267,12 @@ class OpenAIDocumentEmbedder:
                 ids = ", ".join(b[0] for b in batch)
                 msg = "Failed embedding of documents {ids} caused by {exc}"
                 logger.exception(msg, ids=ids, exc=exc)
+                if self.raise_on_failure:
+                    raise exc
                 continue

             embeddings = [el.embedding for el in response.data]
-            all_embeddings.extend(embeddings)
+            doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))

             if "model" not in meta:
                 meta["model"] = response.model

@@ -269,7 +282,7 @@ class OpenAIDocumentEmbedder:
                 meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                 meta["usage"]["total_tokens"] += response.usage.total_tokens

-        return all_embeddings, meta
+        return doc_ids_to_embeddings, meta

     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document]):

@@ -292,12 +305,13 @@ class OpenAIDocumentEmbedder:

         texts_to_embed = self._prepare_texts_to_embed(documents=documents)

-        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
+        doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

-        for doc, emb in zip(documents, embeddings):
-            doc.embedding = emb
+        doc_id_to_document = {doc.id: doc for doc in documents}
+        for doc_id, emb in doc_ids_to_embeddings.items():
+            doc_id_to_document[doc_id].embedding = emb

-        return {"documents": documents, "meta": meta}
+        return {"documents": list(doc_id_to_document.values()), "meta": meta}

     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     async def run_async(self, documents: List[Document]):

@@ -320,9 +334,12 @@ class OpenAIDocumentEmbedder:

         texts_to_embed = self._prepare_texts_to_embed(documents=documents)

-        embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
+        doc_ids_to_embeddings, meta = await self._embed_batch_async(
+            texts_to_embed=texts_to_embed, batch_size=self.batch_size
+        )

-        for doc, emb in zip(documents, embeddings):
-            doc.embedding = emb
+        doc_id_to_document = {doc.id: doc for doc in documents}
+        for doc_id, emb in doc_ids_to_embeddings.items():
+            doc_id_to_document[doc_id].embedding = emb

-        return {"documents": documents, "meta": meta}
+        return {"documents": list(doc_id_to_document.values()), "meta": meta}
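One detail worth noting from the OpenAIDocumentEmbedder hunks above: the new parameter is added after a bare `*`, so it is keyword-only in that class. A hypothetical call, just to illustrate:

# Must be passed by name; a positional value would raise TypeError because of the bare "*".
embedder = OpenAIDocumentEmbedder(raise_on_failure=True)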
@@ -0,0 +1,7 @@
+---
+features:
+  - |
+    Added a raise_on_failure boolean parameter to OpenAIDocumentEmbedder and AzureOpenAIDocumentEmbedder. If set to True, the component will raise an exception when there is an error with the API request. It is set to False by default, so the previous behavior of logging an exception and continuing remains the default.
+fixes:
+  - |
+    Fixed a bug where, if raise_on_failure=False and an error occurred mid-batch, the following embeddings would be paired with the wrong documents.
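To see why the fix matters, consider a mid-batch failure under the old list-based pairing versus the new id-keyed mapping. A self-contained illustration (ids and vectors are made up; this mimics the pairing logic, not the component code):

documents = ["doc-1", "doc-2", "doc-3"]        # batch_size=1; the request for doc-2 fails

# Old approach: embeddings collected in a flat list, then zipped against all documents.
flat_embeddings = [[0.1], [0.3]]               # doc-2's embedding is simply missing
wrong = dict(zip(documents, flat_embeddings))  # {'doc-1': [0.1], 'doc-2': [0.3]} -- doc-2 gets doc-3's vector

# New approach: embeddings keyed by document id, so only the failed document stays unembedded.
by_id = {"doc-1": [0.1], "doc-3": [0.3]}
right = {doc_id: by_id.get(doc_id) for doc_id in documents}  # doc-2 -> None, doc-3 -> [0.3]
print(wrong)
print(right)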
@@ -78,6 +78,7 @@ class TestAzureOpenAIDocumentEmbedder:
                 "default_headers": {},
                 "azure_ad_token_provider": None,
                 "http_client_kwargs": None,
+                "raise_on_failure": False,
             },
         }

@@ -95,6 +96,7 @@ class TestAzureOpenAIDocumentEmbedder:
             default_headers={"x-custom-header": "custom-value"},
             azure_ad_token_provider=default_azure_ad_token_provider,
             http_client_kwargs={"proxy": "http://example.com:3128", "verify": False},
+            raise_on_failure=True,
         )
         data = component.to_dict()
         assert data == {

@@ -118,6 +120,7 @@ class TestAzureOpenAIDocumentEmbedder:
                 "default_headers": {"x-custom-header": "custom-value"},
                 "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
                 "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
+                "raise_on_failure": True,
             },
         }

@@ -144,6 +147,7 @@ class TestAzureOpenAIDocumentEmbedder:
                 "default_headers": {},
                 "azure_ad_token_provider": None,
                 "http_client_kwargs": None,
+                "raise_on_failure": False,
             },
         }
         component = AzureOpenAIDocumentEmbedder.from_dict(data)

@@ -157,6 +161,7 @@ class TestAzureOpenAIDocumentEmbedder:
         assert component.default_headers == {}
         assert component.azure_ad_token_provider is None
         assert component.http_client_kwargs is None
+        assert component.raise_on_failure is False

     def test_from_dict_with_parameters(self, monkeypatch):
         monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")

@@ -181,6 +186,7 @@ class TestAzureOpenAIDocumentEmbedder:
                 "default_headers": {"x-custom-header": "custom-value"},
                 "azure_ad_token_provider": "haystack.utils.azure.default_azure_ad_token_provider",
                 "http_client_kwargs": {"proxy": "http://example.com:3128", "verify": False},
+                "raise_on_failure": True,
             },
         }
         component = AzureOpenAIDocumentEmbedder.from_dict(data)

@@ -194,6 +200,7 @@ class TestAzureOpenAIDocumentEmbedder:
         assert component.default_headers == {"x-custom-header": "custom-value"}
         assert component.azure_ad_token_provider is not None
         assert component.http_client_kwargs == {"proxy": "http://example.com:3128", "verify": False}
+        assert component.raise_on_failure is True

     def test_embed_batch_handles_exceptions_gracefully(self, caplog):
         embedder = AzureOpenAIDocumentEmbedder(

@@ -215,6 +222,22 @@ class TestAzureOpenAIDocumentEmbedder:
         assert len(caplog.records) == 1
         assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.text

+    def test_embed_batch_raises_exception_on_failure(self):
+        embedder = AzureOpenAIDocumentEmbedder(
+            azure_endpoint="https://test.openai.azure.com",
+            api_key=Secret.from_token("fake-api-key"),
+            azure_deployment="text-embedding-ada-002",
+            raise_on_failure=True,
+        )
+        fake_texts_to_embed = {"1": "text1", "2": "text2"}
+        with patch.object(
+            embedder.client.embeddings,
+            "create",
+            side_effect=APIError(message="Mocked error", request=Mock(), body=None),
+        ):
+            with pytest.raises(APIError, match="Mocked error"):
+                embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("AZURE_OPENAI_API_KEY", None) and not os.environ.get("AZURE_OPENAI_ENDPOINT", None),
@@ -125,6 +125,7 @@ class TestOpenAIDocumentEmbedder:
                 "embedding_separator": "\n",
                 "timeout": None,
                 "max_retries": None,
+                "raise_on_failure": False,
             },
         }

@@ -143,6 +144,7 @@ class TestOpenAIDocumentEmbedder:
             embedding_separator=" | ",
             timeout=10.0,
             max_retries=2,
+            raise_on_failure=True,
         )
         data = component.to_dict()
         assert data == {

@@ -162,6 +164,7 @@ class TestOpenAIDocumentEmbedder:
                 "embedding_separator": " | ",
                 "timeout": 10.0,
                 "max_retries": 2,
+                "raise_on_failure": True,
             },
         }

@@ -237,6 +240,45 @@ class TestOpenAIDocumentEmbedder:
         assert len(caplog.records) == 1
         assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.records[0].msg

+    def test_run_handles_exceptions_gracefully(self, caplog):
+        embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), batch_size=1)
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        # Create a successful response for the second call
+        successful_response = Mock()
+        successful_response.data = [
+            Mock(embedding=[0.4, 0.5, 0.6])  # Mock embedding for second doc
+        ]
+        successful_response.model = "text-embedding-ada-002"
+        successful_response.usage = {"prompt_tokens": 10, "total_tokens": 10}
+
+        with patch.object(
+            embedder.client.embeddings,
+            "create",
+            side_effect=[
+                APIError(message="Mocked error", request=Mock(), body=None),  # First call fails
+                successful_response,  # Second call succeeds
+            ],
+        ):
+            result = embedder.run(documents=docs)
+        assert len(result["documents"]) == 2
+        assert result["documents"][0].embedding is None
+        assert result["documents"][1].embedding == [0.4, 0.5, 0.6]
+
+    def test_embed_batch_raises_exception_on_failure(self):
+        embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"), raise_on_failure=True)
+        fake_texts_to_embed = {"1": "text1", "2": "text2"}
+        with patch.object(
+            embedder.client.embeddings,
+            "create",
+            side_effect=APIError(message="Mocked error", request=Mock(), body=None),
+        ):
+            with pytest.raises(APIError, match="Mocked error"):
+                embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
+
     @pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
     @pytest.mark.integration
     def test_run(self):