feat: SentenceTransformersDocumentEmbedder (#5606)

* first draft

* incorporate feedback

* some unit tests

* release notes

* real release notes

* refactored to use a factory class

* allow forcing fresh instances

* first draft

* Update haystack/preview/embedding_backends/sentence_transformers_backend.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* simplify implementation and tests

* add embed_meta_fields implementation

* lg update

* improve meta data embedding; tests

* support non-string metadata

* make factory private

* change return type; improve tests

* warm_up not called in run

* fix typing

* rm unused import

* Remove base test class

* black

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
This commit is contained in:
Stefano Fiorucci 2023-08-28 17:23:41 +03:00 committed by GitHub
parent 89c1813d9f
commit 72fe4fc57b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 226 additions and 0 deletions

View File

@ -0,0 +1,101 @@
from typing import List, Optional, Union
from haystack.preview import component
from haystack.preview import Document
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.
    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        use_auth_token: Union[bool, str, None] = None,
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
            such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
            If None, checks if a GPU can be used.
        :param use_auth_token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with
            the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the
            Document content.
        """
        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device
        self.use_auth_token = use_auth_token
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def warm_up(self):
        """
        Load the embedding backend.

        Idempotent: the backend is created only on the first call; subsequent
        calls are no-ops (the factory is not consulted again).
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
            )

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: Documents to embed. An empty list is accepted and yields an empty result.
        :return: Dict with key ``"documents"`` holding new Document instances carrying embeddings.
        :raises TypeError: If the input is not a list of Documents.
        :raises RuntimeError: If ``warm_up()`` has not been called first.
        """
        # Bug fix: the original checked `documents[0]` unconditionally, which raised
        # IndexError (instead of the documented TypeError / graceful no-op) on an
        # empty list. Guard the element check with `documents and ...`.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )

        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

        # Prepend selected, non-empty metadata values (stringified) to each Document's
        # content, joined by the configured separator.
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key]
            ]
            text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content])
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Documents are treated as immutable: rebuild each one via dict round-trip
        # instead of mutating the inputs in place.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}

View File

@ -0,0 +1,5 @@
---
preview:
- |
Add Sentence Transformers Document Embedder.
It computes embeddings of Documents. The embedding of each Document is stored in the `embedding` field of the Document.

View File

@ -0,0 +1,120 @@
from unittest.mock import patch, MagicMock
import pytest
import numpy as np
from haystack.preview import Document
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
SentenceTransformersDocumentEmbedder,
)
class TestSentenceTransformersDocumentEmbedder:
    """Unit tests for SentenceTransformersDocumentEmbedder (init, warm-up, run, metadata embedding)."""

    @pytest.mark.unit
    def test_init_default(self):
        # All constructor defaults except the mandatory-looking model name.
        embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        assert embedder.model_name_or_path == "model"
        assert embedder.device is None
        assert embedder.use_auth_token is None
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.normalize_embeddings is False

    @pytest.mark.unit
    def test_init_with_parameters(self):
        # Every explicitly passed parameter must be stored unchanged on the instance.
        embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model",
            device="cpu",
            use_auth_token=True,
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        assert embedder.model_name_or_path == "model"
        assert embedder.device == "cpu"
        assert embedder.use_auth_token is True
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.normalize_embeddings is True

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        # The backend factory must be consulted only on warm_up(), not at construction time.
        embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once_with(
            model_name_or_path="model", device=None, use_auth_token=None
        )

    @pytest.mark.unit
    @patch(
        "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        # warm_up() is idempotent: a second call must not create a second backend.
        embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once()

    @pytest.mark.unit
    def test_run(self):
        embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")
        # Stub the backend directly (bypassing warm_up) with random 16-dim vectors.
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()

        documents = [Document(content=f"document number {i}") for i in range(5)]

        result = embedder.run(documents=documents)

        # One embedded Document per input, each carrying a list-of-float embedding.
        assert isinstance(result["documents"], list)
        assert len(result["documents"]) == len(documents)
        for doc in result["documents"]:
            assert isinstance(doc, Document)
            assert isinstance(doc.embedding, list)
            assert isinstance(doc.embedding[0], float)

    @pytest.mark.unit
    def test_run_wrong_input_format(self):
        # Non-Document inputs (a plain string or a list of ints) must raise TypeError.
        embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="model")

        string_input = "text"
        list_integers_input = [1, 2, 3]

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            embedder.run(documents=string_input)

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            embedder.run(documents=list_integers_input)

    @pytest.mark.unit
    def test_embed_metadata(self):
        # The configured meta field must be prepended to each content, joined by the separator,
        # before the texts are handed to the backend.
        embedder = SentenceTransformersDocumentEmbedder(
            model_name_or_path="model", metadata_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        embedder.embedding_backend = MagicMock()

        documents = [
            Document(content=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
        ]

        embedder.run(documents=documents)

        embedder.embedding_backend.embed.assert_called_once_with(
            [
                "meta_value 0\ndocument number 0",
                "meta_value 1\ndocument number 1",
                "meta_value 2\ndocument number 2",
                "meta_value 3\ndocument number 3",
                "meta_value 4\ndocument number 4",
            ],
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
        )