From 283ecf2760865abbf0803710067aecb457534aeb Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Wed, 13 Sep 2023 12:55:06 +0200 Subject: [PATCH] feat: add `prefix` and `suffix` to `SentenceTransformersDocumentEmbedder` (#5745) * add prefix and suffix * fix test --- .../preview/components/embedders/__init__.py | 6 +++ ...sentence_transformers_document_embedder.py | 12 ++++- ...bedder-prefix-suffix-442412c553135406.yaml | 7 +++ ...sentence_transformers_document_embedder.py | 46 +++++++++++++++++++ 4 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 releasenotes/notes/sentence-transformers-doc-embedder-prefix-suffix-442412c553135406.yaml diff --git a/haystack/preview/components/embedders/__init__.py b/haystack/preview/components/embedders/__init__.py index e69de29bb..de8b93958 100644 --- a/haystack/preview/components/embedders/__init__.py +++ b/haystack/preview/components/embedders/__init__.py @@ -0,0 +1,6 @@ +from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder +from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( + SentenceTransformersDocumentEmbedder, +) + +__all__ = ["SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder"] diff --git a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py index 28a6a15fd..40ac170aa 100644 --- a/haystack/preview/components/embedders/sentence_transformers_document_embedder.py +++ b/haystack/preview/components/embedders/sentence_transformers_document_embedder.py @@ -18,6 +18,8 @@ class SentenceTransformersDocumentEmbedder: model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2", device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None, + prefix: str = "", + suffix: str = "", batch_size: int = 32, progress_bar: bool = True, normalize_embeddings: bool = False, @@ -32,6 +34,8 @@ class SentenceTransformersDocumentEmbedder: :param use_auth_token: The API token used to download private models from Hugging Face. If this parameter is set to `True`, then the token generated when running `transformers-cli login` (stored in ~/.huggingface) will be used. + :param prefix: A string to add to the beginning of each Document text before embedding. + :param suffix: A string to add to the end of each Document text before embedding. :param batch_size: Number of strings to encode at once. :param progress_bar: If true, displays progress bar during embedding. :param normalize_embeddings: If set to true, returned vectors will have length 1. @@ -43,6 +47,8 @@ class SentenceTransformersDocumentEmbedder: # TODO: remove device parameter and use Haystack's device management once migrated self.device = device or "cpu" self.use_auth_token = use_auth_token + self.prefix = prefix + self.suffix = suffix self.batch_size = batch_size self.progress_bar = progress_bar self.normalize_embeddings = normalize_embeddings @@ -58,6 +64,8 @@ class SentenceTransformersDocumentEmbedder: model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token, + prefix=self.prefix, + suffix=self.suffix, batch_size=self.batch_size, progress_bar=self.progress_bar, normalize_embeddings=self.normalize_embeddings, @@ -104,7 +112,9 @@ class SentenceTransformersDocumentEmbedder: for key in self.metadata_fields_to_embed if key in doc.metadata and doc.metadata[key] ] - text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + text_to_embed = ( + self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + self.suffix + ) texts_to_embed.append(text_to_embed) embeddings = self.embedding_backend.embed( diff --git a/releasenotes/notes/sentence-transformers-doc-embedder-prefix-suffix-442412c553135406.yaml b/releasenotes/notes/sentence-transformers-doc-embedder-prefix-suffix-442412c553135406.yaml new file mode 100644 index 000000000..f5f88f83f --- /dev/null +++ b/releasenotes/notes/sentence-transformers-doc-embedder-prefix-suffix-442412c553135406.yaml @@ -0,0 +1,7 @@ +--- +preview: + - | + Add `prefix` and `suffix` attributes to `SentenceTransformersDocumentEmbedder`. + They can be used to add a prefix and suffix to the Document text before + embedding it. This is necessary to take full advantage of modern embedding + models, such as E5. diff --git a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py index 5805de0bc..c861c9c48 100644 --- a/test/preview/components/embedders/test_sentence_transformers_document_embedder.py +++ b/test/preview/components/embedders/test_sentence_transformers_document_embedder.py @@ -15,6 +15,8 @@ class TestSentenceTransformersDocumentEmbedder: assert embedder.model_name_or_path == "model" assert embedder.device == "cpu" assert embedder.use_auth_token is None + assert embedder.prefix == "" + assert embedder.suffix == "" assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.normalize_embeddings is False @@ -27,6 +29,8 @@ class TestSentenceTransformersDocumentEmbedder: model_name_or_path="model", device="cuda", use_auth_token=True, + prefix="prefix", + suffix="suffix", batch_size=64, progress_bar=False, normalize_embeddings=True, @@ -36,6 +40,8 @@ class TestSentenceTransformersDocumentEmbedder: assert embedder.model_name_or_path == "model" assert embedder.device == "cuda" assert embedder.use_auth_token is True + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.normalize_embeddings is True @@ -52,6 +58,8 @@ class TestSentenceTransformersDocumentEmbedder: "model_name_or_path": "model", "device": "cpu", "use_auth_token": None, + "prefix": "", + "suffix": "", "batch_size": 32, "progress_bar": True, "normalize_embeddings": False, @@ -66,6 +74,8 @@ class TestSentenceTransformersDocumentEmbedder: model_name_or_path="model", device="cuda", use_auth_token="the-token", + prefix="prefix", + suffix="suffix", batch_size=64, progress_bar=False, normalize_embeddings=True, @@ -79,6 +89,8 @@ class TestSentenceTransformersDocumentEmbedder: "model_name_or_path": "model", "device": "cuda", "use_auth_token": "the-token", + "prefix": "prefix", + "suffix": "suffix", "batch_size": 64, "progress_bar": False, "normalize_embeddings": True, @@ -95,6 +107,8 @@ class TestSentenceTransformersDocumentEmbedder: "model_name_or_path": "model", "device": "cuda", "use_auth_token": "the-token", + "prefix": "prefix", + "suffix": "suffix", "batch_size": 64, "progress_bar": False, "normalize_embeddings": False, @@ -106,6 +120,8 @@ class TestSentenceTransformersDocumentEmbedder: assert component.model_name_or_path == "model" assert component.device == "cuda" assert component.use_auth_token == "the-token" + assert component.prefix == "prefix" + assert component.suffix == "suffix" assert component.batch_size == 64 assert component.progress_bar is False assert component.normalize_embeddings is False @@ -194,3 +210,33 @@ class TestSentenceTransformersDocumentEmbedder: show_progress_bar=True, normalize_embeddings=False, ) + + @pytest.mark.unit + def test_prefix_suffix(self): + embedder = SentenceTransformersDocumentEmbedder( + model_name_or_path="model", + prefix="my_prefix ", + suffix=" my_suffix", + metadata_fields_to_embed=["meta_field"], + embedding_separator="\n", + ) + embedder.embedding_backend = MagicMock() + + documents = [ + Document(text=f"document number {i}", metadata={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + embedder.run(documents=documents) + + embedder.embedding_backend.embed.assert_called_once_with( + [ + "my_prefix meta_value 0\ndocument number 0 my_suffix", + "my_prefix meta_value 1\ndocument number 1 my_suffix", + "my_prefix meta_value 2\ndocument number 2 my_suffix", + "my_prefix meta_value 3\ndocument number 3 my_suffix", + "my_prefix meta_value 4\ndocument number 4 my_suffix", + ], + batch_size=32, + show_progress_bar=True, + normalize_embeddings=False, + )