mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-07 20:46:31 +00:00
feat: add prefix and suffix to SentenceTransformersDocumentEmbedder (#5745)
* add prefix and suffix
* fix test
This commit is contained in:
parent
335a09bc1d
commit
283ecf2760
@ -0,0 +1,6 @@
|
||||
from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
||||
from haystack.preview.components.embedders.sentence_transformers_document_embedder import (
|
||||
SentenceTransformersDocumentEmbedder,
|
||||
)
|
||||
|
||||
__all__ = ["SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder"]
|
||||
@ -18,6 +18,8 @@ class SentenceTransformersDocumentEmbedder:
|
||||
model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
device: Optional[str] = None,
|
||||
use_auth_token: Union[bool, str, None] = None,
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
batch_size: int = 32,
|
||||
progress_bar: bool = True,
|
||||
normalize_embeddings: bool = False,
|
||||
@ -32,6 +34,8 @@ class SentenceTransformersDocumentEmbedder:
|
||||
:param use_auth_token: The API token used to download private models from Hugging Face.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
:param prefix: A string to add to the beginning of each Document text before embedding.
|
||||
:param suffix: A string to add to the end of each Document text before embedding.
|
||||
:param batch_size: Number of strings to encode at once.
|
||||
:param progress_bar: If true, displays a progress bar during embedding.
|
||||
:param normalize_embeddings: If set to true, returned vectors will be normalized to unit length (L2 norm of 1).
|
||||
@ -43,6 +47,8 @@ class SentenceTransformersDocumentEmbedder:
|
||||
# TODO: remove device parameter and use Haystack's device management once migrated
|
||||
self.device = device or "cpu"
|
||||
self.use_auth_token = use_auth_token
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.batch_size = batch_size
|
||||
self.progress_bar = progress_bar
|
||||
self.normalize_embeddings = normalize_embeddings
|
||||
@ -58,6 +64,8 @@ class SentenceTransformersDocumentEmbedder:
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
device=self.device,
|
||||
use_auth_token=self.use_auth_token,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
batch_size=self.batch_size,
|
||||
progress_bar=self.progress_bar,
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
@ -104,7 +112,9 @@ class SentenceTransformersDocumentEmbedder:
|
||||
for key in self.metadata_fields_to_embed
|
||||
if key in doc.metadata and doc.metadata[key]
|
||||
]
|
||||
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.text or ""])
|
||||
text_to_embed = (
|
||||
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + self.suffix
|
||||
)
|
||||
texts_to_embed.append(text_to_embed)
|
||||
|
||||
embeddings = self.embedding_backend.embed(
|
||||
|
||||
@ -0,0 +1,7 @@
|
||||
---
|
||||
preview:
|
||||
- |
|
||||
Add `prefix` and `suffix` attributes to `SentenceTransformersDocumentEmbedder`.
|
||||
They can be used to add a prefix and suffix to the Document text before
|
||||
embedding it. This is necessary to take full advantage of modern embedding
|
||||
models, such as E5, which expect each input text to start with a marker like `passage: `.
|
||||
@ -15,6 +15,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.use_auth_token is None
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.batch_size == 32
|
||||
assert embedder.progress_bar is True
|
||||
assert embedder.normalize_embeddings is False
|
||||
@ -27,6 +29,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
model_name_or_path="model",
|
||||
device="cuda",
|
||||
use_auth_token=True,
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
normalize_embeddings=True,
|
||||
@ -36,6 +40,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert embedder.model_name_or_path == "model"
|
||||
assert embedder.device == "cuda"
|
||||
assert embedder.use_auth_token is True
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.batch_size == 64
|
||||
assert embedder.progress_bar is False
|
||||
assert embedder.normalize_embeddings is True
|
||||
@ -52,6 +58,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"model_name_or_path": "model",
|
||||
"device": "cpu",
|
||||
"use_auth_token": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"normalize_embeddings": False,
|
||||
@ -66,6 +74,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
model_name_or_path="model",
|
||||
device="cuda",
|
||||
use_auth_token="the-token",
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
progress_bar=False,
|
||||
normalize_embeddings=True,
|
||||
@ -79,6 +89,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"model_name_or_path": "model",
|
||||
"device": "cuda",
|
||||
"use_auth_token": "the-token",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"normalize_embeddings": True,
|
||||
@ -95,6 +107,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"model_name_or_path": "model",
|
||||
"device": "cuda",
|
||||
"use_auth_token": "the-token",
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"batch_size": 64,
|
||||
"progress_bar": False,
|
||||
"normalize_embeddings": False,
|
||||
@ -106,6 +120,8 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
assert component.model_name_or_path == "model"
|
||||
assert component.device == "cuda"
|
||||
assert component.use_auth_token == "the-token"
|
||||
assert component.prefix == "prefix"
|
||||
assert component.suffix == "suffix"
|
||||
assert component.batch_size == 64
|
||||
assert component.progress_bar is False
|
||||
assert component.normalize_embeddings is False
|
||||
@ -194,3 +210,33 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
show_progress_bar=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prefix_suffix(self):
|
||||
embedder = SentenceTransformersDocumentEmbedder(
|
||||
model_name_or_path="model",
|
||||
prefix="my_prefix ",
|
||||
suffix=" my_suffix",
|
||||
metadata_fields_to_embed=["meta_field"],
|
||||
embedding_separator="\n",
|
||||
)
|
||||
embedder.embedding_backend = MagicMock()
|
||||
|
||||
documents = [
|
||||
Document(text=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5)
|
||||
]
|
||||
|
||||
embedder.run(documents=documents)
|
||||
|
||||
embedder.embedding_backend.embed.assert_called_once_with(
|
||||
[
|
||||
"my_prefix meta_value 0\ndocument number 0 my_suffix",
|
||||
"my_prefix meta_value 1\ndocument number 1 my_suffix",
|
||||
"my_prefix meta_value 2\ndocument number 2 my_suffix",
|
||||
"my_prefix meta_value 3\ndocument number 3 my_suffix",
|
||||
"my_prefix meta_value 4\ndocument number 4 my_suffix",
|
||||
],
|
||||
batch_size=32,
|
||||
show_progress_bar=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user