Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-12-12 07:17:41 +00:00)
feat: OpenAIDocumentEmbedder (#5822)
* first draft
* release note
* mypy fix
* fix test
* corrections
* pr feedback
* better secrets handling and new tests
* missing imports in embedders/__init__.py
* better format condition
* address feedback
parent 83724b74e3
commit d4aacad5f9
haystack/preview/components/embedders/__init__.py
@@ -2,5 +2,12 @@ from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
 from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
     SentenceTransformersDocumentEmbedder,
 )
+from haystack.preview.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
+from haystack.preview.components.embedders.openai_text_embedder import OpenAITextEmbedder
 
-__all__ = ["SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder"]
+__all__ = [
+    "SentenceTransformersTextEmbedder",
+    "SentenceTransformersDocumentEmbedder",
+    "OpenAITextEmbedder",
+    "OpenAIDocumentEmbedder",
+]

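With these exports in place, both OpenAI embedders can be imported directly from the embedders package; a one-line sketch:

from haystack.preview.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder
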
haystack/preview/components/embedders/openai_document_embedder.py
@@ -0,0 +1,164 @@
from typing import List, Optional, Dict, Any, Tuple
import os

import openai
from tqdm import tqdm


from haystack.preview import component, Document, default_to_dict, default_from_dict


@component
class OpenAIDocumentEmbedder:
    """
    A component for computing Document embeddings using OpenAI models.
    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "text-embedding-ada-002",
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create an OpenAIDocumentEmbedder component.
        :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
            environment variable OPENAI_API_KEY (recommended).
        :param model_name: The name of the model to use.
        :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :param batch_size: Number of Documents to encode at once.
        :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
            to keep the logs clean.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
        """

        if api_key is None:
            try:
                api_key = os.environ["OPENAI_API_KEY"]
            except KeyError as e:
                raise ValueError(
                    "OpenAIDocumentEmbedder expects an OpenAI API key. "
                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
                ) from e

        self.model_name = model_name
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

        openai.api_key = api_key
        if organization is not None:
            openai.organization = organization

    def to_dict(self) -> Dict[str, Any]:
        """
        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
        to the constructor.
        """
        return default_to_dict(
            self,
            model_name=self.model_name,
            organization=self.organization,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            metadata_fields_to_embed=self.metadata_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + self.suffix
            )

            # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
            # replace newlines, which can negatively affect performance.
            text_to_embed = text_to_embed.replace("\n", " ")
            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.
        """

        all_embeddings = []
        metadata = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]
            response = openai.Embedding.create(model=self.model_name, input=batch)
            embeddings = [el["embedding"] for el in response.data]
            all_embeddings.extend(embeddings)

            if "model" not in metadata:
                metadata["model"] = response.model
            if "usage" not in metadata:
                metadata["usage"] = dict(response.usage.items())
            else:
                metadata["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                metadata["usage"]["total_tokens"] += response.usage.total_tokens

        return all_embeddings, metadata

    @component.output_types(documents=List[Document], metadata=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: A list of Documents to embed.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings, "metadata": metadata}

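For orientation, a minimal usage sketch of the new component (assuming the OPENAI_API_KEY environment variable is set and the OpenAI API is reachable):

from haystack.preview import Document
from haystack.preview.components.embedders import OpenAIDocumentEmbedder

# API key is read from OPENAI_API_KEY; it can also be passed explicitly via api_key.
embedder = OpenAIDocumentEmbedder(metadata_fields_to_embed=["topic"], embedding_separator=" | ")
docs = [Document(text="I love cheese", metadata={"topic": "Cuisine"})]

result = embedder.run(documents=docs)
embedded_docs = result["documents"]      # Documents with the `embedding` field populated
print(len(embedded_docs[0].embedding))   # 1536 for text-embedding-ada-002
print(result["metadata"])                # e.g. {"model": "...", "usage": {"prompt_tokens": ..., "total_tokens": ...}}
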
Release note:
@@ -0,0 +1,6 @@
---
preview:
  - |
    Add OpenAI Document Embedder.
    It computes embeddings of Documents using OpenAI models.
    The embedding of each Document is stored in the `embedding` field of the Document.

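To make the text-preparation step concrete, a small self-contained sketch of the string that _prepare_texts_to_embed builds (values chosen to match the unit tests below):

prefix, suffix, separator = "prefix ", " suffix", " | "
meta_values = ["Cuisine"]    # values of the metadata_fields_to_embed keys present on the Document
text = "I love cheese"

# prefix + separator-joined metadata and text + suffix, with newlines replaced by spaces
text_to_embed = (prefix + separator.join(meta_values + [text]) + suffix).replace("\n", " ")
assert text_to_embed == "prefix Cuisine | I love cheese suffix"
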
Unit tests:
@@ -0,0 +1,334 @@
from unittest.mock import patch
import pytest
from typing import List
import numpy as np
import openai
from openai.util import convert_to_openai_object

from haystack.preview import Document
from haystack.preview.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder


def mock_openai_response(
    input: List[str], model: str = "text-embedding-ada-002", **kwargs
) -> openai.openai_object.OpenAIObject:
    dict_response = {
        "object": "list",
        "data": [
            {"object": "embedding", "index": i, "embedding": np.random.rand(1536).tolist()} for i in range(len(input))
        ],
        "model": model,
        "usage": {"prompt_tokens": 4, "total_tokens": 4},
    }

    return convert_to_openai_object(dict_response)

class TestOpenAIDocumentEmbedder:
    @pytest.mark.unit
    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        embedder = OpenAIDocumentEmbedder()

        assert openai.api_key == "fake-api-key"

        assert embedder.model_name == "text-embedding-ada-002"
        assert embedder.organization is None
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.metadata_fields_to_embed == []
        assert embedder.embedding_separator == "\n"

    @pytest.mark.unit
    def test_init_with_parameters(self):
        embedder = OpenAIDocumentEmbedder(
            api_key="fake-api-key",
            model_name="model",
            organization="my-org",
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            metadata_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )
        assert openai.api_key == "fake-api-key"
        assert openai.organization == "my-org"

        assert embedder.organization == "my-org"
        assert embedder.model_name == "model"
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.metadata_fields_to_embed == ["test_field"]
        assert embedder.embedding_separator == " | "

    @pytest.mark.unit
    def test_init_fail_wo_api_key(self, monkeypatch):
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        with pytest.raises(ValueError, match="OpenAIDocumentEmbedder expects an OpenAI API key"):
            OpenAIDocumentEmbedder()

    @pytest.mark.unit
    def test_to_dict(self):
        component = OpenAIDocumentEmbedder(api_key="fake-api-key")
        data = component.to_dict()
        assert data == {
            "type": "OpenAIDocumentEmbedder",
            "init_parameters": {
                "model_name": "text-embedding-ada-002",
                "organization": None,
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "metadata_fields_to_embed": [],
                "embedding_separator": "\n",
            },
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        component = OpenAIDocumentEmbedder(
            api_key="fake-api-key",
            model_name="model",
            organization="my-org",
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            metadata_fields_to_embed=["test_field"],
            embedding_separator=" | ",
        )
        data = component.to_dict()
        assert data == {
            "type": "OpenAIDocumentEmbedder",
            "init_parameters": {
                "model_name": "model",
                "organization": "my-org",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "metadata_fields_to_embed": ["test_field"],
                "embedding_separator": " | ",
            },
        }

    @pytest.mark.unit
    def test_from_dict(self, monkeypatch):
        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
        data = {
            "type": "OpenAIDocumentEmbedder",
            "init_parameters": {
                "model_name": "model",
                "organization": "my-org",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "metadata_fields_to_embed": ["test_field"],
                "embedding_separator": " | ",
            },
        }
        component = OpenAIDocumentEmbedder.from_dict(data)
        assert openai.api_key == "fake-api-key"
        assert component.model_name == "model"
        assert component.organization == "my-org"
        assert openai.organization == "my-org"
        assert component.prefix == "prefix"
        assert component.suffix == "suffix"
        assert component.batch_size == 64
        assert component.progress_bar is False
        assert component.metadata_fields_to_embed == ["test_field"]
        assert component.embedding_separator == " | "

    @pytest.mark.unit
    def test_from_dict_fail_wo_env_var(self, monkeypatch):
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        data = {
            "type": "OpenAIDocumentEmbedder",
            "init_parameters": {
                "model_name": "model",
                "organization": "my-org",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "metadata_fields_to_embed": ["test_field"],
                "embedding_separator": " | ",
            },
        }
        with pytest.raises(ValueError, match="OpenAIDocumentEmbedder expects an OpenAI API key"):
            OpenAIDocumentEmbedder.from_dict(data)

    @pytest.mark.unit
    def test_prepare_texts_to_embed_w_metadata(self):
        documents = [
            Document(text=f"document number {i}:\ncontent", metadata={"meta_field": f"meta_value {i}"})
            for i in range(5)
        ]

        embedder = OpenAIDocumentEmbedder(
            api_key="fake-api-key", metadata_fields_to_embed=["meta_field"], embedding_separator=" | "
        )

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        # note that newline is replaced by space
        assert prepared_texts == [
            "meta_value 0 | document number 0: content",
            "meta_value 1 | document number 1: content",
            "meta_value 2 | document number 2: content",
            "meta_value 3 | document number 3: content",
            "meta_value 4 | document number 4: content",
        ]

    @pytest.mark.unit
    def test_prepare_texts_to_embed_w_suffix(self):
        documents = [Document(text=f"document number {i}") for i in range(5)]

        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix")

        prepared_texts = embedder._prepare_texts_to_embed(documents)

        assert prepared_texts == [
            "my_prefix document number 0 my_suffix",
            "my_prefix document number 1 my_suffix",
            "my_prefix document number 2 my_suffix",
            "my_prefix document number 3 my_suffix",
            "my_prefix document number 4 my_suffix",
        ]

    @pytest.mark.unit
    def test_embed_batch(self):
        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

        with patch(
            "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding"
        ) as openai_embedding_patch:
            openai_embedding_patch.create.side_effect = mock_openai_response
            embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", model_name="model")

            embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

            assert openai_embedding_patch.create.call_count == 3

            assert isinstance(embeddings, list)
            assert len(embeddings) == len(texts)
            for embedding in embeddings:
                assert isinstance(embedding, list)
                assert len(embedding) == 1536
                assert all(isinstance(x, float) for x in embedding)

            # openai.Embedding.create is called 3 times
            assert metadata == {"model": "model", "usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}}

    @pytest.mark.unit
    def test_run(self):
        docs = [
            Document(text="I love cheese", metadata={"topic": "Cuisine"}),
            Document(text="A transformer is a deep learning architecture", metadata={"topic": "ML"}),
        ]

        model = "text-similarity-ada-001"
        with patch(
            "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding"
        ) as openai_embedding_patch:
            openai_embedding_patch.create.side_effect = mock_openai_response
            embedder = OpenAIDocumentEmbedder(
                api_key="fake-api-key",
                model_name=model,
                prefix="prefix ",
                suffix=" suffix",
                metadata_fields_to_embed=["topic"],
                embedding_separator=" | ",
            )

            result = embedder.run(documents=docs)

            openai_embedding_patch.create.assert_called_once_with(
                model=model,
                input=[
                    "prefix Cuisine | I love cheese suffix",
                    "prefix ML | A transformer is a deep learning architecture suffix",
                ],
            )
            documents_with_embeddings = result["documents"]
            metadata = result["metadata"]

            assert isinstance(documents_with_embeddings, list)
            assert len(documents_with_embeddings) == len(docs)
            for doc in documents_with_embeddings:
                assert isinstance(doc, Document)
                assert isinstance(doc.embedding, list)
                assert len(doc.embedding) == 1536
                assert all(isinstance(x, float) for x in doc.embedding)
            assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}}

    @pytest.mark.unit
    def test_run_custom_batch_size(self):
        docs = [
            Document(text="I love cheese", metadata={"topic": "Cuisine"}),
            Document(text="A transformer is a deep learning architecture", metadata={"topic": "ML"}),
        ]

        model = "text-similarity-ada-001"
        with patch(
            "haystack.preview.components.embedders.openai_document_embedder.openai.Embedding"
        ) as openai_embedding_patch:
            openai_embedding_patch.create.side_effect = mock_openai_response
            embedder = OpenAIDocumentEmbedder(
                api_key="fake-api-key",
                model_name=model,
                prefix="prefix ",
                suffix=" suffix",
                metadata_fields_to_embed=["topic"],
                embedding_separator=" | ",
                batch_size=1,
            )

            result = embedder.run(documents=docs)

            assert openai_embedding_patch.create.call_count == 2

            documents_with_embeddings = result["documents"]
            metadata = result["metadata"]

            assert isinstance(documents_with_embeddings, list)
            assert len(documents_with_embeddings) == len(docs)
            for doc in documents_with_embeddings:
                assert isinstance(doc, Document)
                assert isinstance(doc.embedding, list)
                assert len(doc.embedding) == 1536
                assert all(isinstance(x, float) for x in doc.embedding)

            # openai.Embedding.create is called 2 times
            assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}}

    @pytest.mark.unit
    def test_run_wrong_input_format(self):
        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")

        # wrong formats
        string_input = "text"
        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="OpenAIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=string_input)

        with pytest.raises(TypeError, match="OpenAIDocumentEmbedder expects a list of Documents as input"):
            embedder.run(documents=list_integers_input)

    @pytest.mark.unit
    def test_run_on_empty_list(self):
        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")

        empty_list_input = []
        result = embedder.run(documents=empty_list_input)

        assert result["documents"] is not None
        assert not result["documents"]  # empty list