haystack/test/components/embedders/test_sentence_transformers_embedding_backend.py
Arseniy Shkunkov 1fb76ec7e4
feat: add Sparse Embedders based on Sentence Transformers (#9588)
* Added backend class for SparseEncoder and also SentenceTransformersSparseTextEmbedder

* Added SentenceTransformersSparseDocumentEmbedder

* Created a separate _SentenceTransformersSparseEmbeddingBackendFactory and added tests

* Remove unused parameter

* Wrapped output into SparseEmbedding dataclass + fix tests

* Return correct SparseEmbedding, imports and tests

* fix fmt

* Style changes and fixes

* Added a test for embed function

* Added integration test and fixed some other tests

* Add lint fixes

* Fixed positional arguments

* fix types, simplify and more

* fix

* token fixes

* pydocs, small model in test, cache improvement

* try 3.9 for docs

* better to pin click

* release note

* small fix

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
2025-09-19 14:00:13 +00:00

60 lines
2.2 KiB
Python

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import patch
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils.auth import Secret
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
    """
    The factory returns one cached backend per configuration.

    Asking for the same model/device twice (once with keyword arguments, once positionally) must return the very
    same backend object without constructing another `SentenceTransformer`. A different model must yield a new backend.
    """
    # The factory caches backends at class level, so the cache outlives a single test. Start from an empty cache
    # so other tests cannot pre-populate these keys. patch.dict restores the previous contents on exit (even when
    # an assertion fails), so the mock-backed backends created here do not leak into later tests.
    # NOTE(review): relies on the factory's private `_instances` dict being its cache - confirm against the backend.
    with patch.dict(_SentenceTransformersEmbeddingBackendFactory._instances, clear=True):
        embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
            model="my_model", device="cpu"
        )
        # Positional call on purpose: it must resolve to the same cache entry as the keyword call above.
        same_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
            "my_model", "cpu"
        )
        another_embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
            model="another_model", device="cpu"
        )

    assert same_embedding_backend is embedding_backend
    assert another_embedding_backend is not embedding_backend
    # Only the two distinct configurations may build a model; the cache hit must not construct a third one.
    assert mock_sentence_transformer.call_count == 2
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
    """
    Creating a backend forwards every init argument to `SentenceTransformer`.

    The `Secret` auth token is resolved to its plain value and passed as `token`. The keyword arguments that are
    not supplied here (`model_kwargs`, `tokenizer_kwargs`, `config_kwargs`) are forwarded as `None`.
    """
    # Run against an empty factory cache, restored on exit. Otherwise a backend cached by another test for the same
    # configuration would skip construction and make `assert_called_once_with` fail depending on test order.
    # NOTE(review): relies on the factory's private `_instances` dict being its cache - confirm against the backend.
    with patch.dict(_SentenceTransformersEmbeddingBackendFactory._instances, clear=True):
        _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
            model="model",
            device="cpu",
            auth_token=Secret.from_token("fake-api-token"),
            trust_remote_code=True,
            local_files_only=True,
            truncate_dim=256,
            backend="torch",
        )

    mock_sentence_transformer.assert_called_once_with(
        model_name_or_path="model",
        device="cpu",
        token="fake-api-token",
        trust_remote_code=True,
        local_files_only=True,
        truncate_dim=256,
        model_kwargs=None,
        tokenizer_kwargs=None,
        config_kwargs=None,
        backend="torch",
    )
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
    """`embed` passes the input texts and any extra keyword arguments through to the model's `encode`."""
    # An all-defaults `model="model"` request is the key most likely to be cached already by another test.
    # Clear the factory cache (restored on exit) so this test builds, and inspects, its own backend.
    # NOTE(review): relies on the factory's private `_instances` dict being its cache - confirm against the backend.
    with patch.dict(_SentenceTransformersEmbeddingBackendFactory._instances, clear=True):
        embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model="model")

    # Guard against a stale cached backend: the model under test must be the one patched in this test.
    assert embedding_backend.model is mock_sentence_transformer.return_value

    data = ["sentence1", "sentence2"]
    embedding_backend.embed(data=data, normalize_embeddings=True)

    embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True)