mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 12:37:27 +00:00
feat: SentenceTransformersDocumentEmbedder (#5606)
* first draft * incorporate feedback * some unit tests * release notes * real release notes * refactored to use a factory class * allow forcing fresh instances * first draft * Update haystack/preview/embedding_backends/sentence_transformers_backend.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * simplify implementation and tests * add embed_meta_fields implementation * lg update * improve meta data embedding; tests * support non-string metadata * make factory private * change return type; improve tests * warm_up not called in run * fix typing * rm unused import * Remove base test class * black --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
parent
89c1813d9f
commit
72fe4fc57b
@ -0,0 +1,101 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from haystack.preview import component
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
|
||||
|
||||
@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.

    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        use_auth_token: Union[bool, str, None] = None,
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
            such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
            If None, checks if a GPU can be used.
        :param use_auth_token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """
        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device
        self.use_auth_token = use_auth_token
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def warm_up(self):
        """
        Load the embedding backend.

        Calling this more than once is a no-op: the backend is created only on the
        first call and cached on the instance.
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
            )

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.

        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: The Documents to embed. An empty list is accepted.
        :raises TypeError: If `documents` is not a list of Document objects.
        :raises RuntimeError: If `warm_up()` has not been called before `run()`.
        """
        # Checking that the list is non-empty before indexing documents[0] fixes an
        # IndexError the previous guard raised on an empty input list.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

        texts_to_embed = []
        for doc in documents:
            # Prepend the selected (truthy) metadata values to the content so that
            # they influence the embedding; non-string values are stringified.
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key]
            ]
            text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content])
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Rebuild each Document via to_dict/from_dict instead of mutating in place —
        # presumably because preview Documents are treated as immutable; confirm if changed.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Add Sentence Transformers Document Embedder.
|
||||
It computes the embedding of each input Document and stores it in the Document's `embedding` field.
|
||||
@ -0,0 +1,120 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from haystack.preview import Document
|
||||
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
|
||||
SentenceTransformersDocumentEmbedder,
|
||||
)
|
||||
|
||||
|
||||
class TestSentenceTransformersDocumentEmbedder:
    @pytest.mark.unit
    def test_init_default(self):
        """All constructor defaults are stored unchanged as attributes."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        assert doc_embedder.model_name_or_path == "model"
        assert doc_embedder.device is None
        assert doc_embedder.use_auth_token is None
        assert doc_embedder.batch_size == 32
        assert doc_embedder.progress_bar is True
        assert doc_embedder.normalize_embeddings is False

    @pytest.mark.unit
    def test_init_with_parameters(self):
        """Explicit constructor arguments override every default."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model",
            device="cpu",
            use_auth_token=True,
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )

        expected_attributes = {
            "model_name_or_path": "model",
            "device": "cpu",
            "use_auth_token": True,
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
        }
        for attribute, value in expected_attributes.items():
            assert getattr(doc_embedder, attribute) == value

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """warm_up() requests a backend from the factory exactly once, with the init parameters."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()

        doc_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once_with(
            model_name_or_path="model", device=None, use_auth_token=None
        )

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """A repeated warm_up() call must not create a second backend."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()

        for _ in range(2):
            doc_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once()

    @pytest.mark.unit
    def test_run(self):
        """run() returns one embedded Document per input Document."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        doc_embedder.embedding_backend = MagicMock()

        def fake_embed(texts, **kwargs):
            # One random 16-dim vector per input text, as plain Python lists.
            return np.random.rand(len(texts), 16).tolist()

        doc_embedder.embedding_backend.embed = fake_embed

        input_documents = [Document(content=f"document number {i}") for i in range(5)]

        result = doc_embedder.run(documents=input_documents)
        embedded_documents = result["documents"]

        assert isinstance(embedded_documents, list)
        assert len(embedded_documents) == len(input_documents)
        for document in embedded_documents:
            assert isinstance(document, Document)
            assert isinstance(document.embedding, list)
            assert isinstance(document.embedding[0], float)

    @pytest.mark.unit
    def test_run_wrong_input_format(self):
        """run() rejects anything that is not a list of Documents with a TypeError."""
        doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        for bad_input in ("text", [1, 2, 3]):
            with pytest.raises(
                TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
            ):
                doc_embedder.run(documents=bad_input)

    @pytest.mark.unit
    def test_embed_metadata(self):
        """Selected metadata fields are prepended to the content before embedding."""
        doc_embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        doc_embedder.embedding_backend = MagicMock()

        input_documents = [
            Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        doc_embedder.run(documents=input_documents)

        expected_texts = [f"meta_value {i}\ndocument number {i}" for i in range(5)]
        doc_embedder.embedding_backend.embed.assert_called_once_with(
            expected_texts,
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )
||||
Loading…
x
Reference in New Issue
Block a user