Mirror of https://github.com/deepset-ai/haystack.git
feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode (#8806)
* feat: SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder can accept and pass any arguments to SentenceTransformer.encode
* refactor: made the encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder the last positional parameter for backward-compatibility reasons
* docs: added an explanation of encode_kwargs to SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* test: added tests for encode_kwargs in SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* docs: removed empty lines from the docstrings of SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
* refactor: made the encode_kwargs parameter of SentenceTransformersDocumentEmbedder and SentenceTransformersTextEmbedder the last positional parameter for backward compatibility (part II)
Parent: 2828d9e4ae
Commit: d2348ad462
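For orientation before the diff: a minimal usage sketch of the new encode_kwargs parameter. The model name is a placeholder, and "task" (taken from the tests below) is a model-specific encode kwarg that only some models understand; this is an illustrative sketch, not part of the commit.

from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="my-org/my-embedding-model",  # placeholder model name
    encode_kwargs={"task": "retrieval.passage"},  # forwarded verbatim to SentenceTransformer.encode
)
embedder.warm_up()
result = embedder.run(documents=[Document(content="Berlin is the capital of Germany.")])
print(len(result["documents"][0].embedding))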
@@ -56,6 +56,7 @@ class SentenceTransformersDocumentEmbedder:
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Creates a SentenceTransformersDocumentEmbedder component.
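The commit message stresses that encode_kwargs was made the last parameter for backward compatibility; a toy illustration (hypothetical functions, not from the commit) of why appending rather than inserting keeps positional callers working:

def init_old(model, precision="float32"):
    return model, precision

# New parameter appended at the end: existing positional calls keep their meaning.
def init_new(model, precision="float32", encode_kwargs=None):
    return model, precision, encode_kwargs

assert init_new("model", "int8")[:2] == init_old("model", "int8")
# Had encode_kwargs been inserted before precision, init_new("model", "int8")
# would silently bind "int8" to encode_kwargs instead of precision.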
@@ -104,6 +105,10 @@ class SentenceTransformersDocumentEmbedder:
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """

         self.model = model
@@ -121,6 +126,7 @@ class SentenceTransformersDocumentEmbedder:
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision

@@ -155,6 +161,7 @@ class SentenceTransformersDocumentEmbedder:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
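As the to_dict change above shows, encode_kwargs participates in serialization. A sketch of the round trip implied by the tests further down (placeholder model name; an illustration, not committed code):

from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "clustering"})
data = embedder.to_dict()
assert data["init_parameters"]["encode_kwargs"] == {"task": "clustering"}

restored = SentenceTransformersDocumentEmbedder.from_dict(data)
assert restored.encode_kwargs == {"task": "clustering"}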
@@ -232,6 +239,7 @@ class SentenceTransformersDocumentEmbedder:
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )

         for doc, emb in zip(documents, embeddings):
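The **(...) unpacking above merges the user's kwargs into the encode call at runtime. A self-contained sketch of the idiom and of the clash the docstring warns about (a stand-in function, not the real SentenceTransformer.encode):

def encode(sentences, batch_size=32, precision="float32", **kwargs):
    # Stand-in for SentenceTransformer.encode: extra keys land in **kwargs.
    return sentences, batch_size, precision, kwargs

encode_kwargs = {"task": "retrieval.passage"}
encode(["doc one"], batch_size=32, precision="float32", **(encode_kwargs if encode_kwargs else {}))

# Clash example: {"precision": "int8"} repeats an argument the component already
# sets, so Python raises
# TypeError: encode() got multiple values for keyword argument 'precision'.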
@@ -50,6 +50,7 @@ class SentenceTransformersTextEmbedder:
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         config_kwargs: Optional[Dict[str, Any]] = None,
         precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        encode_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Create a SentenceTransformersTextEmbedder component.
@@ -94,6 +95,10 @@ class SentenceTransformersTextEmbedder:
             All non-float32 precisions are quantized embeddings.
             Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
             They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
+        :param encode_kwargs:
+            Additional keyword arguments for `SentenceTransformer.encode` when embedding texts.
+            This parameter is provided for fine customization. Be careful not to clash with already set parameters and
+            avoid passing parameters that change the output type.
         """

         self.model = model
@@ -109,6 +114,7 @@ class SentenceTransformersTextEmbedder:
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.config_kwargs = config_kwargs
+        self.encode_kwargs = encode_kwargs
         self.embedding_backend = None
         self.precision = precision

@@ -141,6 +147,7 @@ class SentenceTransformersTextEmbedder:
             tokenizer_kwargs=self.tokenizer_kwargs,
             config_kwargs=self.config_kwargs,
             precision=self.precision,
+            encode_kwargs=self.encode_kwargs,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -209,5 +216,6 @@ class SentenceTransformersTextEmbedder:
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
             precision=self.precision,
+            **(self.encode_kwargs if self.encode_kwargs else {}),
         )[0]
         return {"embedding": embedding}
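The text embedder gets the same treatment on the query side; a hedged sketch mirroring the test values (placeholder model name; "task" is model-specific, as noted above):

from haystack.components.embedders import SentenceTransformersTextEmbedder

text_embedder = SentenceTransformersTextEmbedder(
    model="my-org/my-embedding-model",  # placeholder
    encode_kwargs={"task": "retrieval.query"},  # forwarded verbatim to SentenceTransformer.encode
)
text_embedder.warm_up()
embedding = text_embedder.run(text="What is the capital of Germany?")["embedding"]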
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    Enhanced `SentenceTransformersDocumentEmbedder` and `SentenceTransformersTextEmbedder` to accept
+    an additional `encode_kwargs` parameter, which is passed directly to the underlying
+    `SentenceTransformer.encode` method for greater flexibility in embedding customization.
@@ -1,9 +1,9 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch

-import random
 import pytest
 import torch

@@ -79,6 +79,7 @@ class TestSentenceTransformersDocumentEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -102,6 +103,7 @@ class TestSentenceTransformersDocumentEmbedder:
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": True},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()

@@ -124,6 +126,7 @@ class TestSentenceTransformersDocumentEmbedder:
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": True},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }

@@ -316,6 +319,20 @@ class TestSentenceTransformersDocumentEmbedder:
             precision="float32",
         )

+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
+        embedder.embedding_backend = MagicMock()
+        documents = [Document(content=f"document number {i}") for i in range(5)]
+        embedder.run(documents=documents)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.passage",
+        )
+
     def test_prefix_suffix(self):
         embedder = SentenceTransformersDocumentEmbedder(
             model="model",
@@ -1,11 +1,11 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
 #
 # SPDX-License-Identifier: Apache-2.0
+import random
 from unittest.mock import MagicMock, patch

-import torch
-import random
 import pytest
+import torch

 from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
 from haystack.utils import ComponentDevice, Secret
@@ -70,6 +70,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "encode_kwargs": None,
                 "config_kwargs": None,
                 "precision": "float32",
             },
@@ -91,6 +92,7 @@ class TestSentenceTransformersTextEmbedder:
             tokenizer_kwargs={"model_max_length": 512},
             config_kwargs={"use_memory_efficient_attention": False},
             precision="int8",
+            encode_kwargs={"task": "clustering"},
         )
         data = component.to_dict()
         assert data == {
@@ -110,6 +112,7 @@ class TestSentenceTransformersTextEmbedder:
                 "tokenizer_kwargs": {"model_max_length": 512},
                 "config_kwargs": {"use_memory_efficient_attention": False},
                 "precision": "int8",
+                "encode_kwargs": {"task": "clustering"},
             },
         }

@@ -297,3 +300,17 @@ class TestSentenceTransformersTextEmbedder:

         assert len(embedding_def) == 768
         assert all(isinstance(el, int) for el in embedding_def)
+
+    def test_embed_encode_kwargs(self):
+        embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
+        embedder.embedding_backend = MagicMock()
+        text = "a nice text to embed"
+        embedder.run(text=text)
+        embedder.embedding_backend.embed.assert_called_once_with(
+            [text],
+            batch_size=32,
+            show_progress_bar=True,
+            normalize_embeddings=False,
+            precision="float32",
+            task="retrieval.query",
+        )