haystack/test/preview/embedding_backends/test_sentence_transformers.py
Stefano Fiorucci 35dfe47186
feat: SentenceTransformersEmbeddingBackend (v2) (#5572)
* first draft

* incorporate feedback

* some unit tests

* release notes

* real release notes

* refactored to use a factory class

* allow forcing fresh instances

* Update haystack/preview/embedding_backends/sentence_transformers_backend.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* simplify implementation and tests

* make factory private

* change return type; improve tests

* fix typing

* rm unused import

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
Co-authored-by: ZanSara <sara.zanzottera@deepset.ai>
2023-08-28 12:32:37 +02:00

43 lines
1.8 KiB
Python

from unittest.mock import patch
import pytest
from haystack.preview.embedding_backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
    """
    Check which factory requests share a backend object and which do not.

    Asking twice for the same model name and device (once with keyword
    arguments, once positionally) must return the very same object, while
    asking for a different model name must return a different object.
    `SentenceTransformer` is patched, so no real model is loaded.
    """
    get_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend

    first = get_backend(model_name_or_path="my_model", device="cpu")
    # Same request as above, deliberately spelled with positional arguments.
    repeated = get_backend("my_model", "cpu")
    other_model = get_backend(model_name_or_path="another_model", device="cpu")

    assert repeated is first
    assert other_model is not first
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
    """
    Check that the factory arguments are handed to `SentenceTransformer`.

    The patched `SentenceTransformer` must have been called exactly once,
    with the same keyword arguments that were given to the factory.
    """
    # NOTE(review): test_factory_behavior asserts that the factory reuses
    # instances; presumably this argument combination must not have been
    # requested earlier in the session for "called once" to hold - confirm.
    init_kwargs = {"model_name_or_path": "model", "device": "cpu", "use_auth_token": "my_token"}

    _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(**init_kwargs)

    mock_sentence_transformer.assert_called_once_with(**init_kwargs)
@pytest.mark.unit
@patch("haystack.preview.embedding_backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
    """
    Check that `embed` passes its input and extra keyword arguments to the model.

    After calling `embed(data=..., normalize_embeddings=True)`, the backend's
    `model.encode` must have been called exactly once, with the sentences as
    the positional argument and `normalize_embeddings=True` as a keyword.
    """
    sentences = ["sentence1", "sentence2"]
    backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model_name_or_path="model")

    backend.embed(data=sentences, normalize_embeddings=True)

    backend.model.encode.assert_called_once_with(sentences, normalize_embeddings=True)