mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
test: Adding unit tests to Ranker (#5167)
* adding unit tests for sentence transformers ranker * Adding more unit tests * Remove empty line * Undo static method * Revert change * Updated indentation and added match message * Remove unneeded paranthesis
This commit is contained in:
parent
cfd703fa3e
commit
1602f3abdd
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
import math
|
||||
from unittest.mock import patch
|
||||
|
||||
from haystack.errors import HaystackError
|
||||
from haystack.schema import Document
|
||||
@ -7,6 +8,63 @@ from haystack.nodes.ranker.base import BaseRanker
|
||||
from haystack.nodes.ranker.sentence_transformers import SentenceTransformersRanker
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ranker_preprocess_batch_queries_and_docs_raises():
|
||||
query_1 = "query 1"
|
||||
query_2 = "query 2"
|
||||
docs = [Document(content="dummy doc 1")]
|
||||
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||
mock_ranker_init.return_value = None
|
||||
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||
with pytest.raises(HaystackError, match="Number of queries must be 1 if a single list of Documents is provided."):
|
||||
_, _, _, _ = ranker._preprocess_batch_queries_and_docs(queries=[query_1, query_2], documents=docs)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ranker_preprocess_batch_queries_and_docs_single_query_single_doc_list():
|
||||
query1 = "query 1"
|
||||
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||
mock_ranker_init.return_value = None
|
||||
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
|
||||
queries=[query1], documents=docs1
|
||||
)
|
||||
assert single_list_of_docs is True
|
||||
assert num_of_docs == [2]
|
||||
assert len(all_queries) == 2
|
||||
assert len(all_docs) == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ranker_preprocess_batch_queries_and_docs_multiple_queries_multiple_doc_lists():
|
||||
query_1 = "query 1"
|
||||
query_2 = "query 2"
|
||||
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||
docs2 = [Document(content="dummy doc 3")]
|
||||
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||
mock_ranker_init.return_value = None
|
||||
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
|
||||
queries=[query_1, query_2], documents=[docs1, docs2]
|
||||
)
|
||||
assert single_list_of_docs is False
|
||||
assert num_of_docs == [2, 1]
|
||||
assert len(all_queries) == 3
|
||||
assert len(all_docs) == 3
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_ranker_get_batches():
|
||||
all_queries = ["query 1", "query 1"]
|
||||
all_docs = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=None)
|
||||
assert next(batches) == (all_queries, all_docs)
|
||||
|
||||
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=1)
|
||||
assert next(batches) == (all_queries[0:1], all_docs[0:1])
|
||||
|
||||
|
||||
def test_ranker(ranker):
|
||||
query = "What is the most important building in King's Landing that has a religious background?"
|
||||
docs = [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user