diff --git a/test/nodes/test_ranker.py b/test/nodes/test_ranker.py index b2b79b4f0..bf1ddcbdd 100644 --- a/test/nodes/test_ranker.py +++ b/test/nodes/test_ranker.py @@ -1,5 +1,6 @@ import pytest import math +from unittest.mock import patch from haystack.errors import HaystackError from haystack.schema import Document @@ -7,6 +8,63 @@ from haystack.nodes.ranker.base import BaseRanker from haystack.nodes.ranker.sentence_transformers import SentenceTransformersRanker +@pytest.mark.unit +def test_ranker_preprocess_batch_queries_and_docs_raises(): + query_1 = "query 1" + query_2 = "query 2" + docs = [Document(content="dummy doc 1")] + with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init: + mock_ranker_init.return_value = None + ranker = SentenceTransformersRanker(model_name_or_path="fake_model") + with pytest.raises(HaystackError, match="Number of queries must be 1 if a single list of Documents is provided."): + _, _, _, _ = ranker._preprocess_batch_queries_and_docs(queries=[query_1, query_2], documents=docs) + + +@pytest.mark.unit +def test_ranker_preprocess_batch_queries_and_docs_single_query_single_doc_list(): + query1 = "query 1" + docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")] + with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init: + mock_ranker_init.return_value = None + ranker = SentenceTransformersRanker(model_name_or_path="fake_model") + num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs( + queries=[query1], documents=docs1 + ) + assert single_list_of_docs is True + assert num_of_docs == [2] + assert len(all_queries) == 2 + assert len(all_docs) == 2 + + +@pytest.mark.unit +def test_ranker_preprocess_batch_queries_and_docs_multiple_queries_multiple_doc_lists(): + query_1 = "query 1" + query_2 = "query 2" + docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")] + docs2 = [Document(content="dummy doc 3")] + with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init: + mock_ranker_init.return_value = None + ranker = SentenceTransformersRanker(model_name_or_path="fake_model") + num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs( + queries=[query_1, query_2], documents=[docs1, docs2] + ) + assert single_list_of_docs is False + assert num_of_docs == [2, 1] + assert len(all_queries) == 3 + assert len(all_docs) == 3 + + +@pytest.mark.unit +def test_ranker_get_batches(): + all_queries = ["query 1", "query 1"] + all_docs = [Document(content="dummy doc 1"), Document(content="dummy doc 2")] + batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=None) + assert next(batches) == (all_queries, all_docs) + + batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=1) + assert next(batches) == (all_queries[0:1], all_docs[0:1]) + + def test_ranker(ranker): query = "What is the most important building in King's Landing that has a religious background?" docs = [