Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-09-04 05:43:29 +00:00)
feat: RecentnessRanker (#5301)
* recency reranker code
* removed
* readd
* edited code
* edit
* mypy test fix
* adding warnings for score method
* fix
* fix
* adding paper link
* comments implementation
* change to predict and predict_batch
* change to predict and predict_batch 2
* adding unit test
* fixes
* small fixes
* fix for unit test
* table driven test
* small fixes
* small fixes2
* adding predict_batch tests
* add recentness_ranker to api reference docs
* implementing feedback
* implementing feedback2
* implementing feedback3
* implementing feedback4
* implementing feedback5
* remove document_map, remove final check if score is not None
* add final check if doc score is not None for mypy

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
parent: c2506866bd
commit: 612c6779fb
@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers"]
+    modules: ["base", "sentence_transformers", "recentness_ranker"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter

haystack/nodes/ranker/recentness_ranker.py (new file, 188 lines)

@@ -0,0 +1,188 @@
import logging
import warnings
from collections import defaultdict
from typing import List, Union, Optional, Dict, Literal

from dateutil.parser import parse, ParserError
from haystack.errors import NodeError

from haystack.nodes.ranker.base import BaseRanker
from haystack.schema import Document

logger = logging.getLogger(__name__)


class RecentnessRanker(BaseRanker):
    outgoing_edges = 1

    def __init__(
        self,
        date_meta_field: str,
        weight: float = 0.5,
        top_k: Optional[int] = None,
        ranking_mode: Literal["reciprocal_rank_fusion", "score"] = "reciprocal_rank_fusion",
    ):
"""
|
||||||
|
This Node is used to rerank retrieved documents based on their age. Newer documents will rank higher.
|
||||||
|
The importance of recentness is parametrized through the weight parameter.
|
||||||
|
|
||||||
|
:param date_meta_field: Identifier pointing to the date field in the metadata.
|
||||||
|
This is a required parameter, since we need dates for sorting.
|
||||||
|
:param weight: in range [0,1].
|
||||||
|
0 disables sorting by age.
|
||||||
|
0.5 content and age have the same impact.
|
||||||
|
1 means sorting only by age, most recent comes first.
|
||||||
|
:param top_k: (optional) How many documents to return. If not provided, all documents will be returned.
|
||||||
|
It can make sense to have large top-k values from the initial retrievers and filter docs down in the
|
||||||
|
RecentnessRanker with this top_k parameter.
|
||||||
|
:param ranking_mode: The mode used to combine retriever and recentness. Possible values are 'reciprocal_rank_fusion' (default) and 'score'.
|
||||||
|
Make sure to use 'score' mode only with retrievers/rankers that give back OK score in range [0,1].
|
||||||
|
"""
|
||||||
|
|
||||||
|
        super().__init__()
        self.date_meta_field = date_meta_field
        self.weight = weight
        self.top_k = top_k
        self.ranking_mode = ranking_mode

        if self.weight < 0 or self.weight > 1:
            raise NodeError(
                """
                Param <weight> needs to be in range [0,1] but was set to '{}'. \n
                Please change param <weight> when initializing the RecentnessRanker.
                """.format(
                    self.weight
                )
            )

    # pylint: disable=arguments-differ
    def predict(  # type: ignore
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """
        This method is used to rank a list of documents based on their age and relevance by:
        1. Adjusting the relevance score from the previous node (or, for RRF, calculating it from scratch, then adjusting) based on the chosen weight in initial parameters.
        2. Sorting the documents based on their age in the metadata, calculating the recentness score, adjusting it by weight as well.
        3. Returning top-k documents (or all, if top-k not provided) in the documents dictionary sorted by final score (relevance score + recentness score).

        :param query: Not used in practice (so can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
        :param documents: Documents provided for ranking.
        :param top_k: (optional) How many documents to return at the end. If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
        """

        try:
            sorted_by_date = sorted(documents, reverse=True, key=lambda x: parse(x.meta[self.date_meta_field]))
        except KeyError:
            raise NodeError(
                """
                Param <date_meta_field> was set to '{}', but document(s) {} do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """.format(
                    self.date_meta_field,
                    ",".join([doc.id for doc in documents if self.date_meta_field not in doc.meta]),
                )
            )

        except ParserError:
            logger.error(
                """
                Could not parse date information for dates: %s\n
                Continuing without sorting by date.
                """,
                " - ".join([doc.meta.get(self.date_meta_field, "identifier wrong") for doc in documents]),
            )

            return documents

        # merge scores for documents sorted both by content and by date.
        # If ranking mode is set to 'reciprocal_rank_fusion', then that is used to combine previous ranking with recency ranking.
        # If ranking mode is set to 'score', then documents will be assigned a recency score in [0,1] and will be re-ranked based on both their recency score and their pre-existing relevance score.
        scores_map: Dict = defaultdict(int)
        if self.ranking_mode not in ["reciprocal_rank_fusion", "score"]:
            raise NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to '{}'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """.format(
                    self.ranking_mode
                )
            )

        for i, doc in enumerate(documents):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight)
            elif self.ranking_mode == "score":
                score = float(0)
                if doc.score is None:
                    warnings.warn("The score was not provided; defaulting to 0")
                elif doc.score < 0 or doc.score > 1:
                    warnings.warn(
                        "The score {} for document {} is outside the [0,1] range; defaulting to 0".format(
                            doc.score, doc.id
                        )
                    )
                else:
                    score = doc.score

                scores_map[doc.id] += score * (1 - self.weight)

        for i, doc in enumerate(sorted_by_date):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * self.weight
            elif self.ranking_mode == "score":
                scores_map[doc.id] += self._calc_recentness_score(rank=i, amount=len(sorted_by_date)) * self.weight

        top_k = top_k or self.top_k or len(documents)

        for doc in documents:
            doc.score = scores_map[doc.id]

        return sorted(documents, key=lambda doc: doc.score if doc.score is not None else -1, reverse=True)[:top_k]

    # pylint: disable=arguments-differ
    def predict_batch(  # type: ignore
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        This method is used to rank either A) a list or B) a list of lists (in case the previous node is JoinDocuments) of documents based on their age and relevance.
        In case A, the predict method defined earlier is applied to the provided list.
        In case B, the predict method is applied to each individual list in the list of lists, and the results are returned as a list of lists.

        :param queries: Not used in practice (so can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
        :param documents: Documents provided for ranking in a list or a list of lists.
        :param top_k: (optional) How many documents to return at the end (per list). If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
        :param batch_size: Not used in practice, so can be left blank.
        """

        if isinstance(documents[0], Document):
            return self.predict("", documents=documents, top_k=top_k)  # type: ignore
        nested_docs = []
        for docs in documents:
            results = self.predict("", documents=docs, top_k=top_k)  # type: ignore
            nested_docs.append(results)

        return nested_docs

    @staticmethod
    def _calculate_rrf(rank: int, k: int = 61) -> float:
        """
        Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper,
        plus 1 as python lists are 0-based and the paper [https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf] used 1-based ranking).
        """
        return 1 / (k + rank)

    @staticmethod
    def _calc_recentness_score(rank: int, amount: int) -> float:
        """
        Calculate recentness score as a linear score between most recent and oldest document.
        This linear scaling is useful to
        a) reduce the effect of outliers and
        b) create recentness scores that are meaningfully distributed in [0,1],
        similar to scores coming from a retriever/ranker.
        """
        return (amount - rank) / amount
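
For illustration only (not part of this commit): a minimal usage sketch of the new node, based on the class above. The documents, dates, and the meta field name publish_date are made-up examples; in 'score' mode the incoming document scores are assumed to already be in the [0,1] range, as the docstring requires.

from haystack.schema import Document
from haystack.nodes.ranker.recentness_ranker import RecentnessRanker

# Hypothetical documents; "publish_date" is an arbitrary meta field name chosen for this sketch.
docs = [
    Document(content="old announcement", meta={"publish_date": "2020-02-11"}, score=0.6),
    Document(content="recent announcement", meta={"publish_date": "2024-02-11"}, score=0.4),
]

# weight=0.5 gives relevance and recentness equal influence.
ranker = RecentnessRanker(date_meta_field="publish_date", weight=0.5, ranking_mode="score")

# The query is not used by this ranker, so it can be left empty.
for doc in ranker.predict(query="", documents=docs, top_k=2):
    print(doc.meta["publish_date"], doc.score)

In a pipeline, the node would typically sit after a retriever or another ranker, which is what supplies the relevance scores that the weight parameter trades off against recency.
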
@@ -1,12 +1,16 @@
 import pytest
 import math
+import warnings
+import logging
+import copy
 from unittest.mock import patch

 import torch
 from haystack.schema import Document
 from haystack.nodes.ranker.base import BaseRanker
 from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
-from haystack.errors import HaystackError
+from haystack.nodes.ranker.recentness_ranker import RecentnessRanker
+from haystack.errors import HaystackError, NodeError


 @pytest.fixture

@@ -499,3 +503,282 @@ def test_cohere_ranker_batch_multiple_queries_multiple_doc_lists(docs, mock_cohe
    assert isinstance(results[0], list)
    assert results[0][0] == docs[4]
    assert results[1][0] == docs[4]


recency_tests_inputs = [
    # Score ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.4833333333333333, "2": 0.7},
            "expected_order": ["2", "1"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="Score ranking mode works as expected",
    ),
    # RRF ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2018-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
            "expected_scores": {"1": 0.01639344262295082, "2": 0.016001024065540194},
            "expected_order": ["1", "2"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="RRF ranking mode works as expected",
    ),
    # Wrong field to find the date
    pytest.param(
        {
            "docs": [
                {"meta": {"data": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_scores": {"1": 0.3, "2": 0.4, "3": 0.6},
            "expected_order": ["1", "2", "3"],
            "expected_exception": NodeError(
                """
                Param <date_meta_field> was set to 'date', but document(s) 1 do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """
            ),
            "top_k": 2,
            "ranking_mode": "score",
        },
        id="Wrong field to find the date",
    ),
    # Date unparsable
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "abcd"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_order": ["1", "2", "3"],
            "expected_logs": [
                (
                    "haystack.nodes.ranker.recentness_ranker",
                    logging.ERROR,
                    """
                Could not parse date information for dates: abcd - 2024-02-11 - 2020-02-11\n
                Continuing without sorting by date.
                """,
                )
            ],
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
        },
        id="Date unparsable",
    ),
    # Wrong score, outside of bounds
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 1.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score 1.3 for document 1 is outside the [0,1] range; defaulting to 0"],
        },
        id="Wrong score, outside of bounds",
    ),
    # Wrong score, not provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score was not provided; defaulting to 0"],
        },
        id="Wrong score, not provided",
    ),
    # Wrong ranking mode provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "blablabla",
            "expected_scores": {"1": 0.01626123744050767, "2": 0.01626123744050767},
            "expected_order": ["1", "2"],
            "expected_exception": NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to 'blablabla'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """
            ),
        },
        id="Wrong ranking mode provided",
    ),
]

@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )
        predict_exception = None
        results = []
        try:
            results = ranker.predict(query="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)


@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )
        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list as input
            results = ranker.predict_batch(queries="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)


@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list_of_lists(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list of lists as input
            results = ranker.predict_batch(queries="", documents=[docs, copy.deepcopy(docs)], top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception, list_of_lists=True)


def check_results(results, test_input, warnings_list, caplog, exception, list_of_lists=False):
    expected_logs_count = 1
    if list_of_lists:
        expected_logs_count = 2

    if "expected_exception" in test_input and test_input["expected_exception"] is not None:
        assert exception.message == test_input["expected_exception"].message
        assert type(exception) == type(test_input["expected_exception"])
        return
    else:
        assert exception is None

    # Check that no warnings were thrown, if we are not expecting any
    if "expected_warning" not in test_input or test_input["expected_warning"] == []:
        assert len(warnings_list) == 0
    # Check that all expected warnings happened, and only those
    else:
        assert len(warnings_list) == len(test_input["expected_warning"])
        for i in range(len(warnings_list)):
            assert test_input["expected_warning"][int(i)] == str(warnings_list[i].message)

    # If we expect logging, compare the log records one by one
    if "expected_logs" not in test_input or test_input["expected_logs"] == []:
        assert len(caplog.record_tuples) == 0
    else:
        assert expected_logs_count * len(test_input["expected_logs"]) == len(caplog.record_tuples)
        for i in range(len(caplog.record_tuples)):
            assert test_input["expected_logs"][int(i / expected_logs_count)] == caplog.record_tuples[i]

    if not list_of_lists:
        check_result_content(results, test_input)
    else:
        for i in results:
            check_result_content(i, test_input)


# Verify the results: that the order and the score of the documents match
def check_result_content(results, test_input):
    assert len(results) == len(test_input["expected_order"])
    for i in range(len(test_input["expected_order"])):
        assert test_input["expected_order"][i] == results[i].id
        if "expected_scores" in test_input:
            assert test_input["expected_scores"][results[i].id] == results[i].score
        else:
            assert results[i].score is None
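
For illustration only (not part of this commit): the expected_scores of the first parameterized case ("Score ranking mode works as expected") can be reproduced by hand from the formulas in recentness_ranker.py above; a small sketch:

# weight=0.5, ranking_mode="score"; docs "1", "2", "3" with scores 0.3, 0.4, 0.6
# and dates 2021-02-11, 2024-02-11, 2020-02-11.
weight = 0.5
docs = {"1": ("2021-02-11", 0.3), "2": ("2024-02-11", 0.4), "3": ("2020-02-11", 0.6)}

# Relevance part: doc.score * (1 - weight)
relevance = {doc_id: score * (1 - weight) for doc_id, (_, score) in docs.items()}

# Recentness part: (amount - rank) / amount, weighted by `weight`; rank comes from sorting
# newest first (ISO date strings sort chronologically, so a plain string sort suffices here).
by_date = sorted(docs, key=lambda doc_id: docs[doc_id][0], reverse=True)  # ["2", "1", "3"]
amount = len(by_date)
recentness = {doc_id: ((amount - rank) / amount) * weight for rank, doc_id in enumerate(by_date)}

combined = {doc_id: relevance[doc_id] + recentness[doc_id] for doc_id in docs}
print(combined)  # {"1": 0.4833..., "2": 0.7, "3": 0.4666...}; top_k=2 keeps "2" and "1"

Assuming the tests above live in the existing ranker test module (test/nodes/test_ranker.py in this repository layout), they can be run with something like: pytest -m unit -k recentness test/nodes/test_ranker.py
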