# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import random
from unittest.mock import MagicMock, patch

import pytest
import torch

from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersTextEmbedder:
def test_init_default(self):
embedder = SentenceTransformersTextEmbedder(model="model")
assert embedder.model == "model"
assert embedder.device == ComponentDevice.resolve_device(None)
assert embedder.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.trust_remote_code is False
assert embedder.truncate_dim is None
assert embedder.precision == "float32"

    def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
model="model",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
truncate_dim=256,
precision="int8",
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.trust_remote_code is True
assert embedder.truncate_dim == 256
assert embedder.precision == "int8"

    def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": ComponentDevice.from_str("cpu").to_dict(),
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"encode_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
}

    def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersTextEmbedder(
model="model",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
batch_size=64,
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
truncate_dim=256,
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": False},
precision="int8",
encode_kwargs={"task": "clustering"},
)
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "model",
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
"trust_remote_code": True,
"truncate_dim": 256,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "int8",
"encode_kwargs": {"task": "clustering"},
},
}

    def test_to_dict_not_serialize_token(self):
component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
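        # token-based secrets must never end up in the serialized dictionary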
with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
component.to_dict()

    def test_from_dict(self):
data = {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": ComponentDevice.from_str("cpu").to_dict(),
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": None,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": False},
"precision": "float32",
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model == "model"
assert component.device == ComponentDevice.from_str("cpu")
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.prefix == ""
assert component.suffix == ""
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim is None
assert component.model_kwargs == {"torch_dtype": torch.float32}
assert component.tokenizer_kwargs == {"model_max_length": 512}
assert component.config_kwargs == {"use_memory_efficient_attention": False}
assert component.precision == "float32"

    def test_from_dict_no_default_parameters(self):
data = {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model == "sentence-transformers/all-mpnet-base-v2"
assert component.device == ComponentDevice.resolve_device(None)
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.prefix == ""
assert component.suffix == ""
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim is None
assert component.precision == "float32"

    def test_from_dict_none_device(self):
data = {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": 256,
"precision": "int8",
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model == "model"
assert component.device == ComponentDevice.resolve_device(None)
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.prefix == ""
assert component.suffix == ""
assert component.batch_size == 32
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim == 256
assert component.precision == "int8"

    @patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(
model="model",
token=None,
device=ComponentDevice.from_str("cpu"),
tokenizer_kwargs={"model_max_length": 512},
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
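        # the backend is a MagicMock here, so this assignment is effectively inert;
        # with a real model, warm_up applies tokenizer_kwargs["model_max_length"] as max_seq_length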
embedder.embedding_backend.model.max_seq_length = 512
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model",
device="cpu",
auth_token=None,
trust_remote_code=False,
truncate_dim=None,
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
config_kwargs=None,
)

    @patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup_doesnt_reload(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(model="model")
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
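        # a second warm_up() must be a no-op: the cached backend is reused, not rebuilt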
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once()

    def test_run(self):
embedder = SentenceTransformersTextEmbedder(model="model")
embedder.embedding_backend = MagicMock()
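        # stub out embed() so the test returns one 16-dimensional vector per input
        # string without downloading or running a real model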
embedder.embedding_backend.embed = lambda x, **kwargs: [
[random.random() for _ in range(16)] for _ in range(len(x))
]
text = "a nice text to embed"
result = embedder.run(text=text)
embedding = result["embedding"]
assert isinstance(embedding, list)
assert all(isinstance(el, float) for el in embedding)

    def test_run_wrong_input_format(self):
embedder = SentenceTransformersTextEmbedder(model="model")
embedder.embedding_backend = MagicMock()
list_integers_input = [1, 2, 3]
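        # a list is rejected: the text embedder handles single strings, while
        # SentenceTransformersDocumentEmbedder is the component for lists of Documents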
with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
embedder.run(text=list_integers_input)

    @pytest.mark.integration
def test_run_trunc(self, monkeypatch):
"""
        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768-dimensional dense vector space
"""
monkeypatch.delenv("HF_API_TOKEN", raising=False) # https://github.com/deepset-ai/haystack/issues/8811
checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
text = "a nice text to embed"
embedder_def = SentenceTransformersTextEmbedder(model=checkpoint)
embedder_def.warm_up()
result_def = embedder_def.run(text=text)
embedding_def = result_def["embedding"]

        embedder_trunc = SentenceTransformersTextEmbedder(model=checkpoint, truncate_dim=128)
embedder_trunc.warm_up()
result_trunc = embedder_trunc.run(text=text)
embedding_trunc = result_trunc["embedding"]
assert len(embedding_def) == 768
assert len(embedding_trunc) == 128

    @pytest.mark.integration
def test_run_quantization(self):
"""
        sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a 768-dimensional dense vector space
"""
checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
text = "a nice text to embed"
embedder_def = SentenceTransformersTextEmbedder(model=checkpoint, precision="int8")
embedder_def.warm_up()
result_def = embedder_def.run(text=text)
embedding_def = result_def["embedding"]
assert len(embedding_def) == 768
assert all(isinstance(el, int) for el in embedding_def)

    def test_embed_encode_kwargs(self):
embedder = SentenceTransformersTextEmbedder(model="model", encode_kwargs={"task": "retrieval.query"})
embedder.embedding_backend = MagicMock()
text = "a nice text to embed"
embedder.run(text=text)
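        # the init-time encode_kwargs entry ("task") must be forwarded to embed()
        # together with the standard arguments the component always passes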
embedder.embedding_backend.embed.assert_called_once_with(
[text],
batch_size=32,
show_progress_bar=True,
normalize_embeddings=False,
precision="float32",
task="retrieval.query",
)
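

# Illustrative sketch (not part of the test suite): how init-time encode_kwargs
# reach SentenceTransformer.encode through the component, mirroring the mocked
# call above. The checkpoint name and the "retrieval.query" task value are
# assumptions for illustration: task-aware models such as jinaai/jina-embeddings-v3
# accept a `task` argument in encode(), and running this would download the model.
#
#     embedder = SentenceTransformersTextEmbedder(
#         model="jinaai/jina-embeddings-v3",
#         trust_remote_code=True,
#         encode_kwargs={"task": "retrieval.query"},
#     )
#     embedder.warm_up()
#     embedding = embedder.run(text="a nice text to embed")["embedding"]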