test: Adding unit tests to Ranker (#5167)

* adding unit tests for sentence transformers ranker

* Adding more unit tests

* Remove empty line

* Undo static method

* Revert change

* Updated indentation and added match message

* Remove unneeded paranthesis
This commit is contained in:
Sebastian 2023-06-22 15:23:23 +02:00 committed by GitHub
parent cfd703fa3e
commit 1602f3abdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
import pytest
import math
from unittest.mock import patch
from haystack.errors import HaystackError
from haystack.schema import Document
@ -7,6 +8,63 @@ from haystack.nodes.ranker.base import BaseRanker
from haystack.nodes.ranker.sentence_transformers import SentenceTransformersRanker
@pytest.mark.unit
def test_ranker_preprocess_batch_queries_and_docs_raises():
query_1 = "query 1"
query_2 = "query 2"
docs = [Document(content="dummy doc 1")]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
with pytest.raises(HaystackError, match="Number of queries must be 1 if a single list of Documents is provided."):
_, _, _, _ = ranker._preprocess_batch_queries_and_docs(queries=[query_1, query_2], documents=docs)
@pytest.mark.unit
def test_ranker_preprocess_batch_queries_and_docs_single_query_single_doc_list():
query1 = "query 1"
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
queries=[query1], documents=docs1
)
assert single_list_of_docs is True
assert num_of_docs == [2]
assert len(all_queries) == 2
assert len(all_docs) == 2
@pytest.mark.unit
def test_ranker_preprocess_batch_queries_and_docs_multiple_queries_multiple_doc_lists():
query_1 = "query 1"
query_2 = "query 2"
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
docs2 = [Document(content="dummy doc 3")]
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
mock_ranker_init.return_value = None
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
queries=[query_1, query_2], documents=[docs1, docs2]
)
assert single_list_of_docs is False
assert num_of_docs == [2, 1]
assert len(all_queries) == 3
assert len(all_docs) == 3
@pytest.mark.unit
def test_ranker_get_batches():
all_queries = ["query 1", "query 1"]
all_docs = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=None)
assert next(batches) == (all_queries, all_docs)
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=1)
assert next(batches) == (all_queries[0:1], all_docs[0:1])
def test_ranker(ranker):
query = "What is the most important building in King's Landing that has a religious background?"
docs = [