2023-08-28 17:23:26 +03:00
|
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
import pytest
|
2024-02-05 13:17:01 +01:00
|
|
|
from haystack.utils.auth import Secret
|
2023-08-28 17:23:26 +03:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder
|
2023-08-28 17:23:26 +03:00
|
|
|
|
|
|
|
|
|
|
|
class TestSentenceTransformersTextEmbedder:
    """Unit tests for ``SentenceTransformersTextEmbedder``.

    No real sentence-transformers model is loaded: warm-up tests patch the
    backend factory, and run tests replace ``embedding_backend`` with a mock.
    """

    def test_init_default(self):
        """Only ``model`` is given; every other init parameter keeps its default."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        assert embedder.model == "model"
        assert embedder.device == "cpu"
        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.normalize_embeddings is False

    def test_init_with_parameters(self):
        """Every explicitly passed init parameter is stored unchanged on the instance."""
        embedder = SentenceTransformersTextEmbedder(
            model="model",
            device="cuda",
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        assert embedder.model == "model"
        assert embedder.device == "cuda"
        assert embedder.token == Secret.from_token("fake-api-token")
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.normalize_embeddings is True

    def test_to_dict(self):
        """Default-constructed component serializes its type path and default init parameters."""
        component = SentenceTransformersTextEmbedder(model="model")
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
            "init_parameters": {
                # The default token is an env-var secret, so it serializes as its env-var reference.
                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                "model": "model",
                "device": "cpu",
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "normalize_embeddings": False,
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """Custom init parameters, including an env-var token, round into ``to_dict`` output."""
        component = SentenceTransformersTextEmbedder(
            model="model",
            device="cuda",
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
        )
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
            "init_parameters": {
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "model": "model",
                "device": "cuda",
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
            },
        }

    def test_to_dict_not_serialize_token(self):
        """A token-based (literal) secret must never be serialized: ``to_dict`` raises."""
        component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
        with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
            component.to_dict()

    @patch(
        "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """The backend is created only on ``warm_up``, with the configured model and device."""
        embedder = SentenceTransformersTextEmbedder(model="model", token=None)
        # Construction alone must not build the backend.
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        # With token=None the factory is expected to receive auth_token=None.
        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

    @patch(
        "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """Calling ``warm_up`` twice requests the backend from the factory only once."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once()

    def test_run(self):
        """``run`` on a single string returns one flat embedding: a list of floats."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        # Replace the real backend so no model is loaded. The fake ``embed`` returns one
        # 16-dimensional vector per input text; a fixed seed keeps the test deterministic.
        rng = np.random.default_rng(seed=42)
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed = lambda x, **kwargs: rng.random((len(x), 16)).tolist()

        text = "a nice text to embed"

        result = embedder.run(text=text)
        embedding = result["embedding"]

        assert isinstance(embedding, list)
        # A flat list of floats must be one row of the fake output, hence 16 values.
        assert len(embedding) == 16
        assert all(isinstance(el, float) for el in embedding)

    def test_run_wrong_input_format(self):
        """Non-string input (here a list of ints) is rejected with a ``TypeError``."""
        embedder = SentenceTransformersTextEmbedder(model="model")
        embedder.embedding_backend = MagicMock()

        list_integers_input = [1, 2, 3]

        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
            embedder.run(text=list_integers_input)
|