mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-15 17:43:55 +00:00
test: Adding unit tests to Ranker (#5167)
* adding unit tests for sentence transformers ranker * Adding more unit tests * Remove empty line * Undo static method * Revert change * Updated indentation and added match message * Remove unneeded paranthesis
This commit is contained in:
parent
cfd703fa3e
commit
1602f3abdd
@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import math
|
import math
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from haystack.errors import HaystackError
|
from haystack.errors import HaystackError
|
||||||
from haystack.schema import Document
|
from haystack.schema import Document
|
||||||
@ -7,6 +8,63 @@ from haystack.nodes.ranker.base import BaseRanker
|
|||||||
from haystack.nodes.ranker.sentence_transformers import SentenceTransformersRanker
|
from haystack.nodes.ranker.sentence_transformers import SentenceTransformersRanker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_ranker_preprocess_batch_queries_and_docs_raises():
|
||||||
|
query_1 = "query 1"
|
||||||
|
query_2 = "query 2"
|
||||||
|
docs = [Document(content="dummy doc 1")]
|
||||||
|
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||||
|
mock_ranker_init.return_value = None
|
||||||
|
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||||
|
with pytest.raises(HaystackError, match="Number of queries must be 1 if a single list of Documents is provided."):
|
||||||
|
_, _, _, _ = ranker._preprocess_batch_queries_and_docs(queries=[query_1, query_2], documents=docs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_ranker_preprocess_batch_queries_and_docs_single_query_single_doc_list():
|
||||||
|
query1 = "query 1"
|
||||||
|
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||||
|
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||||
|
mock_ranker_init.return_value = None
|
||||||
|
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||||
|
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
|
||||||
|
queries=[query1], documents=docs1
|
||||||
|
)
|
||||||
|
assert single_list_of_docs is True
|
||||||
|
assert num_of_docs == [2]
|
||||||
|
assert len(all_queries) == 2
|
||||||
|
assert len(all_docs) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_ranker_preprocess_batch_queries_and_docs_multiple_queries_multiple_doc_lists():
|
||||||
|
query_1 = "query 1"
|
||||||
|
query_2 = "query 2"
|
||||||
|
docs1 = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||||
|
docs2 = [Document(content="dummy doc 3")]
|
||||||
|
with patch("haystack.nodes.ranker.sentence_transformers.SentenceTransformersRanker.__init__") as mock_ranker_init:
|
||||||
|
mock_ranker_init.return_value = None
|
||||||
|
ranker = SentenceTransformersRanker(model_name_or_path="fake_model")
|
||||||
|
num_of_docs, all_queries, all_docs, single_list_of_docs = ranker._preprocess_batch_queries_and_docs(
|
||||||
|
queries=[query_1, query_2], documents=[docs1, docs2]
|
||||||
|
)
|
||||||
|
assert single_list_of_docs is False
|
||||||
|
assert num_of_docs == [2, 1]
|
||||||
|
assert len(all_queries) == 3
|
||||||
|
assert len(all_docs) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_ranker_get_batches():
|
||||||
|
all_queries = ["query 1", "query 1"]
|
||||||
|
all_docs = [Document(content="dummy doc 1"), Document(content="dummy doc 2")]
|
||||||
|
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=None)
|
||||||
|
assert next(batches) == (all_queries, all_docs)
|
||||||
|
|
||||||
|
batches = SentenceTransformersRanker._get_batches(all_queries=all_queries, all_docs=all_docs, batch_size=1)
|
||||||
|
assert next(batches) == (all_queries[0:1], all_docs[0:1])
|
||||||
|
|
||||||
|
|
||||||
def test_ranker(ranker):
|
def test_ranker(ranker):
|
||||||
query = "What is the most important building in King's Landing that has a religious background?"
|
query = "What is the most important building in King's Landing that has a religious background?"
|
||||||
docs = [
|
docs = [
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user