feat: add prefix and suffix to SentenceTransformersDocumentEmbedder (#5745)

* add prefix and suffix

* fix test
This commit is contained in:
Stefano Fiorucci 2023-09-13 12:55:06 +02:00 committed by GitHub
parent 335a09bc1d
commit 283ecf2760
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 1 deletions

View File

@ -0,0 +1,6 @@
from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
SentenceTransformersDocumentEmbedder,
)
# Names exported by this package: the two embedder classes imported above.
__all__ = ["SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder"]

View File

@ -18,6 +18,8 @@ class SentenceTransformersDocumentEmbedder:
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
use_auth_token: Union[bool, str, None] = None,
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
progress_bar: bool = True,
normalize_embeddings: bool = False,
@ -32,6 +34,8 @@ class SentenceTransformersDocumentEmbedder:
:param use_auth_token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.
:param prefix: A string to add to the beginning of each Document text before embedding.
:param suffix: A string to add to the end of each Document text before embedding.
:param batch_size: Number of strings to encode at once.
:param progress_bar: If true, displays progress bar during embedding.
:param normalize_embeddings: If set to true, returned vectors will have length 1.
@ -43,6 +47,8 @@ class SentenceTransformersDocumentEmbedder:
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.use_auth_token = use_auth_token
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
@ -58,6 +64,8 @@ class SentenceTransformersDocumentEmbedder:
model_name_or_path=self.model_name_or_path,
device=self.device,
use_auth_token=self.use_auth_token,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
@ -104,7 +112,9 @@ class SentenceTransformersDocumentEmbedder:
for key in self.metadata_fields_to_embed
if key in doc.metadata and doc.metadata[key]
]
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.text or ""])
text_to_embed = (
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + self.suffix
)
texts_to_embed.append(text_to_embed)
embeddings = self.embedding_backend.embed(

View File

@ -0,0 +1,7 @@
---
preview:
- |
Add `prefix` and `suffix` attributes to `SentenceTransformersDocumentEmbedder`.
They can be used to add a prefix and suffix to the Document text before
embedding it. This is necessary to take full advantage of modern embedding
models, such as E5.

View File

@ -15,6 +15,8 @@ class TestSentenceTransformersDocumentEmbedder:
assert embedder.model_name_or_path == "model"
assert embedder.device == "cpu"
assert embedder.use_auth_token is None
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
@ -27,6 +29,8 @@ class TestSentenceTransformersDocumentEmbedder:
model_name_or_path="model",
device="cuda",
use_auth_token=True,
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
@ -36,6 +40,8 @@ class TestSentenceTransformersDocumentEmbedder:
assert embedder.model_name_or_path == "model"
assert embedder.device == "cuda"
assert embedder.use_auth_token is True
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
@ -52,6 +58,8 @@ class TestSentenceTransformersDocumentEmbedder:
"model_name_or_path": "model",
"device": "cpu",
"use_auth_token": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
@ -66,6 +74,8 @@ class TestSentenceTransformersDocumentEmbedder:
model_name_or_path="model",
device="cuda",
use_auth_token="the-token",
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
@ -79,6 +89,8 @@ class TestSentenceTransformersDocumentEmbedder:
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": "the-token",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
@ -95,6 +107,8 @@ class TestSentenceTransformersDocumentEmbedder:
"model_name_or_path": "model",
"device": "cuda",
"use_auth_token": "the-token",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": False,
@ -106,6 +120,8 @@ class TestSentenceTransformersDocumentEmbedder:
assert component.model_name_or_path == "model"
assert component.device == "cuda"
assert component.use_auth_token == "the-token"
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is False
@ -194,3 +210,33 @@ class TestSentenceTransformersDocumentEmbedder:
show_progress_bar=True,
normalize_embeddings=False,
)
@pytest.mark.unit
def test_prefix_suffix(self):
    """
    Check that `prefix` and `suffix` wrap the text handed to the embedding backend.

    Each Document carries one metadata field that is configured to be embedded, so the
    expected string is: prefix + metadata value + separator + document text + suffix.
    """
    docs = []
    for idx in range(5):
        docs.append(Document(text=f"document number {idx}", metadata={"meta_field": f"meta_value {idx}"}))

    embedder = SentenceTransformersDocumentEmbedder(
        model_name_or_path="model",
        prefix="my_prefix ",
        suffix=" my_suffix",
        metadata_fields_to_embed=["meta_field"],
        embedding_separator="\n",
    )
    # Stand in for the real backend so the arguments passed to `embed` can be inspected.
    embedder.embedding_backend = MagicMock()

    embedder.run(documents=docs)

    # Spelled out literally (rather than computed) so the expectation is independent
    # of the string-building logic under test.
    expected_texts = [
        "my_prefix meta_value 0\ndocument number 0 my_suffix",
        "my_prefix meta_value 1\ndocument number 1 my_suffix",
        "my_prefix meta_value 2\ndocument number 2 my_suffix",
        "my_prefix meta_value 3\ndocument number 3 my_suffix",
        "my_prefix meta_value 4\ndocument number 4 my_suffix",
    ]
    embedder.embedding_backend.embed.assert_called_once_with(
        expected_texts,
        batch_size=32,
        show_progress_bar=True,
        normalize_embeddings=False,
    )