mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-24 01:10:45 +00:00

* first draft
* incorporate feedback
* some unit tests
* release notes
* real release notes
* refactored to use a factory class
* allow forcing fresh instances
* Update haystack/preview/embedding_backends/sentence_transformers_backend.py

  Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
* simplify implementation and tests
* make factory private
* change return type; improve tests
* fix typing
* rm unused import

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
43 lines
1.8 KiB
Python
43 lines
1.8 KiB
Python
from unittest.mock import patch
|
|
import pytest
|
|
from haystack.preview.embedding_backends.sentence_transformers_backend import (
|
|
_SentenceTransformersEmbeddingBackendFactory,
|
|
)
|
|
|
|
|
|
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
    """Check which requests to the factory share a backend instance.

    Asking twice for the same model name and device (once with keyword
    arguments, once positionally) must return the very same object, while
    asking for a different model name must return a distinct object.
    """
    get_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend

    # First request uses keyword arguments, the repeat uses positional ones.
    first_backend = get_backend(model_name_or_path="my_model", device="cpu")
    repeated_backend = get_backend("my_model", "cpu")
    other_backend = get_backend(model_name_or_path="another_model", device="cpu")

    assert repeated_backend is first_backend
    assert other_backend is not first_backend
|
|
|
|
|
|
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
    """Check that the patched SentenceTransformer is constructed exactly once.

    The constructor must receive the same keyword arguments that were given
    to the factory.
    """
    # One dict feeds both the call and the expectation so they cannot drift apart.
    init_kwargs = {"model_name_or_path": "model", "device": "cpu", "use_auth_token": "my_token"}

    _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(**init_kwargs)

    mock_sentence_transformer.assert_called_once_with(**init_kwargs)
|
|
|
|
|
|
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
    """Check how ``embed`` calls the model's ``encode``.

    ``encode`` must be called exactly once, with the data as the positional
    argument and the extra keyword arguments passed through.
    """
    sentences = ["sentence1", "sentence2"]
    backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")

    backend.embed(data=sentences, normalize_embeddings=True)

    # The model is a mock (SentenceTransformer is patched), so encode records its calls.
    encode_mock = backend.model.encode
    encode_mock.assert_called_once_with(sentences, normalize_embeddings=True)
|