haystack/test/components/embedders/test_sentence_transformers_embedding_backend.py
ZanSara 288ed150c9
feat!: Rename model_name or model_name_or_path to model in all Embedder classes (#6733)
* rename model parameter in the openai doc embedder

* fix tests for openai doc embedder

* rename model parameter in the openai text embedder

* fix tests for openai text embedder

* rename model parameter in the st doc embedder

* fix tests for st doc embedder

* rename model parameter in the st backend

* fix tests for st backend

* rename model parameter in the st text embedder

* fix tests for st text embedder

* fix docstring

* fix pipeline utils

* fix e2e

* add release note (reno)

* fix the indexing pipeline _create_embedder function

* fix e2e eval rag pipeline

* pytest
2024-01-12 15:30:17 +01:00

40 lines
1.7 KiB
Python

from unittest.mock import patch
import pytest
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_factory_behavior(mock_sentence_transformer):
    """Check that the backend factory reuses backends for identical arguments.

    SentenceTransformer is patched in the backend module, so the test does not
    load a real model. The first two requests use the same model and device,
    once by keyword and once positionally, and must return the very same
    backend object. A request for a different model must return a distinct one.
    """
    factory = _SentenceTransformersEmbeddingBackendFactory

    first = factory.get_embedding_backend(model="my_model", device="cpu")
    # Same arguments given positionally: the test expects this to hit the same
    # cached backend as the keyword form above.
    repeated = factory.get_embedding_backend("my_model", "cpu")
    different_model = factory.get_embedding_backend(model="another_model", device="cpu")

    assert repeated is first
    assert different_model is not first
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
    """Check that creating a backend constructs SentenceTransformer with the given arguments.

    The factory's public ``model`` argument is expected to reach the
    SentenceTransformer constructor as ``model_name_or_path``. ``device`` and
    ``use_auth_token`` are forwarded unchanged.
    """
    # NOTE(review): assert_called_once_with requires the constructor to actually
    # run here. If a backend with these exact arguments were already cached by
    # the factory earlier in the session, it might not be. Confirm the cache
    # cannot be pre-populated for this key.
    expected_kwargs = {"model_name_or_path": "model", "device": "cpu", "use_auth_token": "my_token"}

    _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
        model="model", device="cpu", use_auth_token="my_token"
    )

    mock_sentence_transformer.assert_called_once_with(**expected_kwargs)
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_embedding_function_with_kwargs(mock_sentence_transformer):
    """Check that ``embed`` forwards its data and extra keyword arguments to ``model.encode``.

    The data list is expected to be passed positionally, and any additional
    keyword argument (here ``normalize_embeddings``) is passed through as-is.
    """
    sentences = ["sentence1", "sentence2"]
    backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(model="model")

    backend.embed(data=sentences, normalize_embeddings=True)

    # backend.model is presumably the patched SentenceTransformer's instance,
    # so encode is a Mock that records its calls.
    encode = backend.model.encode
    encode.assert_called_once_with(sentences, normalize_embeddings=True)