2024-05-09 15:40:36 +02:00
|
|
|
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
|
|
#
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2023-08-28 13:32:37 +03:00
|
|
|
from unittest.mock import patch
|
2024-03-14 11:14:04 +01:00
|
|
|
|
2023-08-28 13:32:37 +03:00
|
|
|
import pytest
|
2024-03-14 11:14:04 +01:00
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
from haystack.components.embedders.backends.sentence_transformers_backend import (
|
2023-08-28 13:32:37 +03:00
|
|
|
_SentenceTransformersEmbeddingBackendFactory,
|
|
|
|
)
|
2024-02-05 13:17:01 +01:00
|
|
|
from haystack.utils.auth import Secret
|
2023-08-28 13:32:37 +03:00
|
|
|
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
    """Check that the factory reuses a backend for repeated arguments and builds a new one for another model."""
    factory = _SentenceTransformersEmbeddingBackendFactory

    original = factory.get_embedding_backend(model="my_model", device="cpu")
    # Positional form on purpose: the same model/device passed positionally must
    # yield the very same backend object as the keyword form above.
    repeated = factory.get_embedding_backend("my_model", "cpu")
    different = factory.get_embedding_backend(model="another_model", device="cpu")

    assert repeated is original
    assert different is not original
|
|
|
|
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
    """Check that factory arguments are forwarded to the ``SentenceTransformer`` constructor."""
    token = Secret.from_token("fake-api-token")
    _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
        model="model",
        device="cpu",
        auth_token=token,
        trust_remote_code=True,
        truncate_dim=256,
    )

    # The secret is expected to arrive as its plain string under `use_auth_token`;
    # model/tokenizer kwargs were not supplied, so they are expected to be None.
    expected_kwargs = {
        "model_name_or_path": "model",
        "device": "cpu",
        "use_auth_token": "fake-api-token",
        "trust_remote_code": True,
        "truncate_dim": 256,
        "model_kwargs": None,
        "tokenizer_kwargs": None,
    }
    mock_sentence_transformer.assert_called_once_with(**expected_kwargs)
|
|
|
|
|
|
|
|
|
2023-11-24 14:48:43 +01:00
|
|
|
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
    """Check that extra keyword arguments to ``embed`` are passed through to ``model.encode``."""
    backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model="model")

    sentences = ["sentence1", "sentence2"]
    extra = {"normalize_embeddings": True}
    backend.embed(data=sentences, **extra)

    # The texts go to `encode` positionally; the extra option goes through unchanged as a keyword.
    backend.model.encode.assert_called_once_with(sentences, **extra)
|