Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-27 15:08:43 +00:00
added precision parameter to sentence transformers embeddings (#8179)
* added `precision` parameter to sentence transformers embeddings
* fixed test
* Update haystack/components/embedders/sentence_transformers_document_embedder.py
  Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
* Update test/components/embedders/test_sentence_transformers_text_embedder.py
  Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
* Update test/components/embedders/test_sentence_transformers_text_embedder.py
  Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
* fix format
* Update sentence_transformers_text_embedder.py
---------
Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent ec02817f14
commit 4c798470b2
haystack/components/embedders/sentence_transformers_document_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 from haystack import Document, component, default_from_dict, default_to_dict
 from haystack.components.embedders.backends.sentence_transformers_backend import (
@@ -54,6 +54,7 @@ class SentenceTransformersDocumentEmbedder:
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     ):
         """
         Creates a SentenceTransformersDocumentEmbedder component.
@@ -95,6 +96,11 @@ class SentenceTransformersDocumentEmbedder:
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param precision:
+            The precision to use for the embeddings.
+            All non-float32 precisions are quantized embeddings.
+            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
+            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
         """

         self.model = model
@@ -112,6 +118,7 @@ class SentenceTransformersDocumentEmbedder:
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.embedding_backend = None
+        self.precision = precision

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -142,6 +149,7 @@ class SentenceTransformersDocumentEmbedder:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            precision=self.precision,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -215,6 +223,7 @@ class SentenceTransformersDocumentEmbedder:
             batch_size=self.batch_size,
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
+            precision=self.precision,
         )

         for doc, emb in zip(documents, embeddings):
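For reference, a minimal usage sketch of the document embedder with the new parameter. The model name, document text, and the "int8" choice are illustrative, not part of this commit; it assumes Haystack 2.x with sentence-transformers installed.

from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

# Illustrative values: any Sentence Transformers checkpoint and any supported precision work.
embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    precision="int8",  # "float32" (default), "int8", "uint8", "binary", or "ubinary"
)
embedder.warm_up()  # loads the underlying SentenceTransformer model

docs = [Document(content="Quantized embeddings trade a little accuracy for much smaller vectors.")]
result = embedder.run(documents=docs)
print(result["documents"][0].embedding[:5])  # quantized (integer) values instead of floats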
haystack/components/embedders/sentence_transformers_text_embedder.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 from haystack import component, default_from_dict, default_to_dict
 from haystack.components.embedders.backends.sentence_transformers_backend import (
@@ -48,6 +48,7 @@ class SentenceTransformersTextEmbedder:
         truncate_dim: Optional[int] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
     ):
         """
         Create a SentenceTransformersTextEmbedder component.
@@ -85,6 +86,11 @@ class SentenceTransformersTextEmbedder:
         :param tokenizer_kwargs:
             Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
             Refer to specific model documentation for available kwargs.
+        :param precision:
+            The precision to use for the embeddings.
+            All non-float32 precisions are quantized embeddings.
+            Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
+            They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
         """

         self.model = model
@@ -100,6 +106,7 @@ class SentenceTransformersTextEmbedder:
         self.model_kwargs = model_kwargs
         self.tokenizer_kwargs = tokenizer_kwargs
         self.embedding_backend = None
+        self.precision = precision

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
@@ -128,6 +135,7 @@ class SentenceTransformersTextEmbedder:
             truncate_dim=self.truncate_dim,
             model_kwargs=self.model_kwargs,
             tokenizer_kwargs=self.tokenizer_kwargs,
+            precision=self.precision,
         )
         if serialization_dict["init_parameters"].get("model_kwargs") is not None:
             serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
@@ -192,5 +200,6 @@ class SentenceTransformersTextEmbedder:
             batch_size=self.batch_size,
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
+            precision=self.precision,
         )[0]
         return {"embedding": embedding}
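The text embedder counterpart works the same way; a sketch under the same assumptions, mirroring the integration test added further down (checkpoint name and input text are taken from that test and are just examples).

from haystack.components.embedders import SentenceTransformersTextEmbedder

text_embedder = SentenceTransformersTextEmbedder(
    model="sentence-transformers/paraphrase-albert-small-v2",
    precision="int8",
)
text_embedder.warm_up()
embedding = text_embedder.run(text="a nice text to embed")["embedding"]
# Still 768 values for this model, but small integers rather than float32.
print(len(embedding), embedding[:5])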
releasenotes/notes/release-note-42273d88ce3e2b2e.yaml (new file, 5 lines)
@@ -0,0 +1,5 @@
+---
+features:
+  - |
+    Add `precision` parameter to Sentence Transformers Embedders, which allows quantized
+    embeddings. Especially useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
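To make the "reducing the size of the embeddings of a corpus" claim concrete, a back-of-the-envelope sketch; the corpus size and dimension are made-up numbers, and the binary estimate assumes one bit per dimension packed into bytes, which is how sentence-transformers describes its binary precisions.

# Rough storage estimate for 1,000,000 embeddings of dimension 768.
n_docs, dim = 1_000_000, 768
sizes = {
    "float32": n_docs * dim * 4,    # 4 bytes per dimension -> ~3.07 GB
    "int8": n_docs * dim * 1,       # 1 byte per dimension  -> ~0.77 GB
    "binary": n_docs * (dim // 8),  # 1 bit per dimension   -> ~0.10 GB
}
for name, size in sizes.items():
    print(f"{name:>8}: {size / 1e9:.2f} GB")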
test/components/embedders/test_sentence_transformers_document_embedder.py
@@ -27,6 +27,7 @@ class TestSentenceTransformersDocumentEmbedder:
         assert embedder.embedding_separator == "\n"
         assert embedder.trust_remote_code is False
         assert embedder.truncate_dim is None
+        assert embedder.precision == "float32"

     def test_init_with_parameters(self):
         embedder = SentenceTransformersDocumentEmbedder(
@@ -42,6 +43,7 @@ class TestSentenceTransformersDocumentEmbedder:
             embedding_separator=" | ",
             trust_remote_code=True,
             truncate_dim=256,
+            precision="int8",
         )
         assert embedder.model == "model"
         assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -55,6 +57,7 @@ class TestSentenceTransformersDocumentEmbedder:
         assert embedder.embedding_separator == " | "
         assert embedder.trust_remote_code
         assert embedder.truncate_dim == 256
+        assert embedder.precision == "int8"

     def test_to_dict(self):
         component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@@ -76,6 +79,7 @@ class TestSentenceTransformersDocumentEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "precision": "float32",
             },
         }

@@ -95,6 +99,7 @@ class TestSentenceTransformersDocumentEmbedder:
             truncate_dim=256,
             model_kwargs={"torch_dtype": torch.float32},
             tokenizer_kwargs={"model_max_length": 512},
+            precision="int8",
         )
         data = component.to_dict()

@@ -115,6 +120,7 @@ class TestSentenceTransformersDocumentEmbedder:
                 "truncate_dim": 256,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "precision": "int8",
             },
         }

@@ -134,6 +140,7 @@ class TestSentenceTransformersDocumentEmbedder:
             "truncate_dim": 256,
             "model_kwargs": {"torch_dtype": "torch.float32"},
             "tokenizer_kwargs": {"model_max_length": 512},
+            "precision": "int8",
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
@@ -155,6 +162,7 @@ class TestSentenceTransformersDocumentEmbedder:
         assert component.truncate_dim == 256
         assert component.model_kwargs == {"torch_dtype": torch.float32}
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.precision == "int8"

     def test_from_dict_no_default_parameters(self):
         component = SentenceTransformersDocumentEmbedder.from_dict(
@@ -175,6 +183,7 @@ class TestSentenceTransformersDocumentEmbedder:
         assert component.trust_remote_code is False
         assert component.meta_fields_to_embed == []
         assert component.truncate_dim is None
+        assert component.precision == "float32"

     def test_from_dict_none_device(self):
         init_parameters = {
@@ -190,6 +199,7 @@ class TestSentenceTransformersDocumentEmbedder:
             "meta_fields_to_embed": ["meta_field"],
             "trust_remote_code": True,
             "truncate_dim": None,
+            "precision": "float32",
         }
         component = SentenceTransformersDocumentEmbedder.from_dict(
             {
@@ -209,6 +219,7 @@ class TestSentenceTransformersDocumentEmbedder:
         assert component.trust_remote_code
         assert component.meta_fields_to_embed == ["meta_field"]
         assert component.truncate_dim is None
+        assert component.precision == "float32"

     @patch(
         "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -292,6 +303,7 @@ class TestSentenceTransformersDocumentEmbedder:
             batch_size=32,
             show_progress_bar=True,
             normalize_embeddings=False,
+            precision="float32",
         )

     def test_prefix_suffix(self):
@@ -319,4 +331,5 @@ class TestSentenceTransformersDocumentEmbedder:
             batch_size=32,
             show_progress_bar=True,
             normalize_embeddings=False,
+            precision="float32",
         )
test/components/embedders/test_sentence_transformers_text_embedder.py
@@ -24,6 +24,7 @@ class TestSentenceTransformersTextEmbedder:
         assert embedder.normalize_embeddings is False
         assert embedder.trust_remote_code is False
         assert embedder.truncate_dim is None
+        assert embedder.precision == "float32"

     def test_init_with_parameters(self):
         embedder = SentenceTransformersTextEmbedder(
@@ -37,6 +38,7 @@ class TestSentenceTransformersTextEmbedder:
             normalize_embeddings=True,
             trust_remote_code=True,
             truncate_dim=256,
+            precision="int8",
         )
         assert embedder.model == "model"
         assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -48,6 +50,7 @@ class TestSentenceTransformersTextEmbedder:
         assert embedder.normalize_embeddings is True
         assert embedder.trust_remote_code is True
         assert embedder.truncate_dim == 256
+        assert embedder.precision == "int8"

     def test_to_dict(self):
         component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@@ -67,6 +70,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": None,
                 "tokenizer_kwargs": None,
+                "precision": "float32",
             },
         }

@@ -84,6 +88,7 @@ class TestSentenceTransformersTextEmbedder:
             truncate_dim=256,
             model_kwargs={"torch_dtype": torch.float32},
             tokenizer_kwargs={"model_max_length": 512},
+            precision="int8",
         )
         data = component.to_dict()
         assert data == {
@@ -101,6 +106,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": 256,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "precision": "int8",
             },
         }

@@ -125,6 +131,7 @@ class TestSentenceTransformersTextEmbedder:
                 "truncate_dim": None,
                 "model_kwargs": {"torch_dtype": "torch.float32"},
                 "tokenizer_kwargs": {"model_max_length": 512},
+                "precision": "float32",
             },
         }
         component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -140,6 +147,7 @@ class TestSentenceTransformersTextEmbedder:
         assert component.truncate_dim is None
         assert component.model_kwargs == {"torch_dtype": torch.float32}
         assert component.tokenizer_kwargs == {"model_max_length": 512}
+        assert component.precision == "float32"

     def test_from_dict_no_default_parameters(self):
         data = {
@@ -157,6 +165,7 @@ class TestSentenceTransformersTextEmbedder:
         assert component.normalize_embeddings is False
         assert component.trust_remote_code is False
         assert component.truncate_dim is None
+        assert component.precision == "float32"

     def test_from_dict_none_device(self):
         data = {
@@ -172,6 +181,7 @@ class TestSentenceTransformersTextEmbedder:
                 "normalize_embeddings": False,
                 "trust_remote_code": False,
                 "truncate_dim": 256,
+                "precision": "int8",
             },
         }
         component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -185,6 +195,7 @@ class TestSentenceTransformersTextEmbedder:
         assert component.normalize_embeddings is False
         assert component.trust_remote_code is False
         assert component.truncate_dim == 256
+        assert component.precision == "int8"

     @patch(
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -255,3 +266,19 @@ class TestSentenceTransformersTextEmbedder:

         assert len(embedding_def) == 768
         assert len(embedding_trunc) == 128
+
+    @pytest.mark.integration
+    def test_run_quantization(self):
+        """
+        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768 dimensional dense vector space
+        """
+        checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
+        text = "a nice text to embed"
+
+        embedder_def = SentenceTransformersTextEmbedder(model=checkpoint, precision="int8")
+        embedder_def.warm_up()
+        result_def = embedder_def.run(text=text)
+        embedding_def = result_def["embedding"]
+
+        assert len(embedding_def) == 768
+        assert all(isinstance(el, int) for el in embedding_def)
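The last assertion holds because int8 quantization yields integer components. To see what each precision does at the sentence-transformers level, a hedged sketch: the shapes and dtypes below are what recent sentence-transformers releases document for encode() with a precision argument, so verify against the installed version.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/paraphrase-albert-small-v2")
for precision in ("float32", "int8", "ubinary"):
    emb = model.encode(["a nice text to embed"], precision=precision)
    print(precision, emb.shape, emb.dtype)
# Expected roughly: float32 -> (1, 768) float32, int8 -> (1, 768) int8,
# ubinary -> (1, 96) uint8, since binary precisions pack 8 dimensions per byte.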