Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-09-03 21:33:40 +00:00
feat: RecentnessRanker (#5301)
* recency reranker code
* removed
* readd
* edited code
* edit
* mypy test fix
* adding warnings for score method
* fix
* fix
* adding paper link
* comments implementation
* change to predict and predict_batch
* change to predict and predict_batch 2
* adding unit test
* fixes
* small fixes
* fix for unit test
* table driven test
* small fixes
* small fixes2
* adding predict_batch tests
* add recentness_ranker to api reference docs
* implementing feedback
* implementing feedback2
* implementing feedback3
* implementing feedback4
* implementing feedback5
* remove document_map, remove final check if score is not None
* add final check if doc score is not None for mypy

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
parent c2506866bd
commit 612c6779fb
@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers"]
+    modules: ["base", "sentence_transformers", "recentness_ranker"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
haystack/nodes/ranker/recentness_ranker.py  (new file, 188 lines)
@@ -0,0 +1,188 @@
import logging
import warnings
from collections import defaultdict
from typing import List, Union, Optional, Dict, Literal

from dateutil.parser import parse, ParserError
from haystack.errors import NodeError

from haystack.nodes.ranker.base import BaseRanker
from haystack.schema import Document

logger = logging.getLogger(__name__)


class RecentnessRanker(BaseRanker):
    outgoing_edges = 1

    def __init__(
        self,
        date_meta_field: str,
        weight: float = 0.5,
        top_k: Optional[int] = None,
        ranking_mode: Literal["reciprocal_rank_fusion", "score"] = "reciprocal_rank_fusion",
    ):
        """
        This node reranks retrieved documents based on their age. Newer documents rank higher.
        The importance of recentness is controlled through the weight parameter.

        :param date_meta_field: Identifier pointing to the date field in the metadata.
            This is a required parameter, since we need dates for sorting.
        :param weight: A value in the range [0,1].
            0 disables sorting by age.
            0.5 gives content and age the same impact.
            1 sorts only by age, with the most recent document first.
        :param top_k: (optional) How many documents to return. If not provided, all documents are returned.
            It can make sense to use a large top_k in the initial retrievers and then narrow the documents
            down in the RecentnessRanker with this top_k parameter.
        :param ranking_mode: The mode used to combine the retriever score and recentness.
            Possible values are 'reciprocal_rank_fusion' (default) and 'score'.
            Use 'score' mode only with retrievers or rankers that return a score in the range [0,1].
        """

        super().__init__()
        self.date_meta_field = date_meta_field
        self.weight = weight
        self.top_k = top_k
        self.ranking_mode = ranking_mode

        if self.weight < 0 or self.weight > 1:
            raise NodeError(
                """
                Param <weight> needs to be in range [0,1] but was set to '{}'. \n
                Please change param <weight> when initializing the RecentnessRanker.
                """.format(
                    self.weight
                )
            )

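As a quick usage sketch (illustrative only; the "date" metadata key, the example contents, dates, and scores below are assumptions, not part of the PR), the ranker is given documents that carry a date in their metadata and reranks them with predict, which is defined further down in this file:

from haystack.schema import Document
from haystack.nodes.ranker.recentness_ranker import RecentnessRanker

# Illustrative documents: relevance scores come from a previous retriever or ranker,
# and the "date" metadata key matches date_meta_field below.
docs = [
    Document(content="older but more relevant", meta={"date": "2020-02-11"}, score=0.6),
    Document(content="newer but less relevant", meta={"date": "2024-02-11"}, score=0.4),
]

ranker = RecentnessRanker(date_meta_field="date", weight=0.5, ranking_mode="score")
reranked = ranker.predict(query="", documents=docs, top_k=2)  # newer document ends up first
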
    # pylint: disable=arguments-differ
    def predict(  # type: ignore
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """
        This method ranks a list of documents based on their age and relevance by:
        1. Adjusting the relevance score from the previous node (or, for reciprocal rank fusion, calculating it from scratch, then adjusting) based on the chosen weight.
        2. Sorting the documents by the age given in the metadata and calculating a recentness score, also adjusted by the weight.
        3. Returning the top_k documents (or all documents, if top_k is not provided) sorted by the final score (relevance score + recentness score).

        :param query: Not used in practice (so it can be left blank), as this ranker does not rank documents by their semantic closeness to the query.
        :param documents: Documents provided for ranking.
        :param top_k: (optional) How many documents to return. If not provided, all documents are returned, sorted by relevance and recentness (adjusted by weight).
        """

        try:
            sorted_by_date = sorted(documents, reverse=True, key=lambda x: parse(x.meta[self.date_meta_field]))
        except KeyError:
            raise NodeError(
                """
                Param <date_meta_field> was set to '{}', but document(s) {} do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """.format(
                    self.date_meta_field,
                    ",".join([doc.id for doc in documents if self.date_meta_field not in doc.meta]),
                )
            )

        except ParserError:
            logger.error(
                """
                Could not parse date information for dates: %s\n
                Continuing without sorting by date.
                """,
                " - ".join([doc.meta.get(self.date_meta_field, "identifier wrong") for doc in documents]),
            )

            return documents

        # Merge scores for documents sorted both by content and by date.
        # If ranking_mode is 'reciprocal_rank_fusion', RRF is used to combine the previous ranking with the recency ranking.
        # If ranking_mode is 'score', documents are assigned a recency score in [0,1] and are re-ranked based on both
        # their recency score and their pre-existing relevance score.
        scores_map: Dict = defaultdict(int)
        if self.ranking_mode not in ["reciprocal_rank_fusion", "score"]:
            raise NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to '{}'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """.format(
                    self.ranking_mode
                )
            )

        for i, doc in enumerate(documents):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight)
            elif self.ranking_mode == "score":
                score = float(0)
                if doc.score is None:
                    warnings.warn("The score was not provided; defaulting to 0")
                elif doc.score < 0 or doc.score > 1:
                    warnings.warn(
                        "The score {} for document {} is outside the [0,1] range; defaulting to 0".format(
                            doc.score, doc.id
                        )
                    )
                else:
                    score = doc.score

                scores_map[doc.id] += score * (1 - self.weight)

        for i, doc in enumerate(sorted_by_date):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * self.weight
            elif self.ranking_mode == "score":
                scores_map[doc.id] += self._calc_recentness_score(rank=i, amount=len(sorted_by_date)) * self.weight

        top_k = top_k or self.top_k or len(documents)

        for doc in documents:
            doc.score = scores_map[doc.id]

        return sorted(documents, key=lambda doc: doc.score if doc.score is not None else -1, reverse=True)[:top_k]

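To illustrate how 'score' mode combines the two parts with the default weight of 0.5, here is the arithmetic behind one of the expected values in the unit tests further below (three documents with relevance scores 0.3, 0.4 and 0.6, where the 0.4 document is the most recent); this is illustrative arithmetic only, not additional code from the PR:

# The most recent of three documents gets the linear recentness score (3 - 0) / 3 = 1.0
# from _calc_recentness_score, which is then weighted and added to the weighted relevance score.
relevance_part = (1 - 0.5) * 0.4
recentness_part = 0.5 * 1.0
combined = relevance_part + recentness_part
assert round(combined, 6) == 0.7  # expected score of document "2" in the tests below
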
    # pylint: disable=arguments-differ
    def predict_batch(  # type: ignore
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        This method ranks either A) a list of documents or B) a list of lists of documents (if the previous node is JoinDocuments) based on their age and relevance.
        In case A, the predict method defined above is applied to the provided list.
        In case B, the predict method is applied to each individual list, and the results are returned as a list of lists.

        :param queries: Not used in practice (so it can be left blank), as this ranker does not rank documents by their semantic closeness to the query.
        :param documents: Documents provided for ranking, as a list or a list of lists.
        :param top_k: (optional) How many documents to return per list. If not provided, all documents are returned, sorted by relevance and recentness (adjusted by weight).
        :param batch_size: Not used in practice, so it can be left blank.
        """

        if isinstance(documents[0], Document):
            return self.predict("", documents=documents, top_k=top_k)  # type: ignore
        nested_docs = []
        for docs in documents:
            results = self.predict("", documents=docs, top_k=top_k)  # type: ignore
            nested_docs.append(results)

        return nested_docs

    @staticmethod
    def _calculate_rrf(rank: int, k: int = 61) -> float:
        """
        Calculates the reciprocal rank fusion. The constant k is set to 61 (60 was suggested by the original paper,
        plus 1 as Python lists are 0-based and the paper [https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf] used 1-based ranking).
        """
        return 1 / (k + rank)

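As a worked example of the fused score (matching the RRF test case further below, and shown here only as illustrative arithmetic): with weight=0.5, a document ranked first (rank 0) both by the previous node and by date is scored from both orderings with the same constant k=61:

# Rank 0 in both the relevance ordering and the date ordering, weight = 0.5.
fused = (1 - 0.5) * (1 / (61 + 0)) + 0.5 * (1 / (61 + 0))
assert abs(fused - 0.01639344262295082) < 1e-12  # expected score of document "1" in the RRF test below
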
    @staticmethod
    def _calc_recentness_score(rank: int, amount: int) -> float:
        """
        Calculate the recentness score as a linear score between the most recent and the oldest document.
        This linear scaling is useful to
        a) reduce the effect of outliers and
        b) create recentness scores that are meaningfully distributed in [0,1],
        similar to scores coming from a retriever or ranker.
        """
        return (amount - rank) / amount
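For a quick feel for the linear recentness score (values implied by the formula above, not quoted from the PR): with three documents sorted from newest to oldest, the scores are 1.0, 2/3 and 1/3 before weighting:

# Illustrative only: recentness scores for three documents at ranks 0, 1 and 2.
recentness_scores = [(3 - rank) / 3 for rank in range(3)]
print(recentness_scores)  # [1.0, 0.6666666666666666, 0.3333333333333333]
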
@@ -1,12 +1,16 @@
 import pytest
 import math
+import warnings
+import logging
+import copy
 from unittest.mock import patch

 import torch
 from haystack.schema import Document
 from haystack.nodes.ranker.base import BaseRanker
 from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
-from haystack.errors import HaystackError
+from haystack.nodes.ranker.recentness_ranker import RecentnessRanker
+from haystack.errors import HaystackError, NodeError


 @pytest.fixture
@@ -499,3 +503,282 @@ def test_cohere_ranker_batch_multiple_queries_multiple_doc_lists(docs, mock_cohere_ranker):
    assert isinstance(results[0], list)
    assert results[0][0] == docs[4]
    assert results[1][0] == docs[4]


recency_tests_inputs = [
    # Score ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.4833333333333333, "2": 0.7},
            "expected_order": ["2", "1"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="Score ranking mode works as expected",
    ),
    # RRF ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2018-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
            "expected_scores": {"1": 0.01639344262295082, "2": 0.016001024065540194},
            "expected_order": ["1", "2"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="RRF ranking mode works as expected",
    ),
    # Wrong field to find the date
    pytest.param(
        {
            "docs": [
                {"meta": {"data": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_scores": {"1": 0.3, "2": 0.4, "3": 0.6},
            "expected_order": ["1", "2", "3"],
            "expected_exception": NodeError(
                """
                Param <date_meta_field> was set to 'date', but document(s) 1 do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """
            ),
            "top_k": 2,
            "ranking_mode": "score",
        },
        id="Wrong field to find the date",
    ),
    # Date unparsable
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "abcd"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_order": ["1", "2", "3"],
            "expected_logs": [
                (
                    "haystack.nodes.ranker.recentness_ranker",
                    logging.ERROR,
                    """
                Could not parse date information for dates: abcd - 2024-02-11 - 2020-02-11\n
                Continuing without sorting by date.
                """,
                )
            ],
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
        },
        id="Date unparsable",
    ),
    # Wrong score, outside of bounds
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 1.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score 1.3 for document 1 is outside the [0,1] range; defaulting to 0"],
        },
        id="Wrong score, outside of bounds",
    ),
    # Wrong score, not provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score was not provided; defaulting to 0"],
        },
        id="Wrong score, not provided",
    ),
    # Wrong ranking mode provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "blablabla",
            "expected_scores": {"1": 0.01626123744050767, "2": 0.01626123744050767},
            "expected_order": ["1", "2"],
            "expected_exception": NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to 'blablabla'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """
            ),
        },
        id="Wrong ranking mode provided",
    ),
]


@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )
        predict_exception = None
        results = []
        try:
            results = ranker.predict(query="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)


@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )
        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list as input
            results = ranker.predict_batch(queries="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)


@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list_of_lists(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list of lists as input
            results = ranker.predict_batch(queries="", documents=[docs, copy.deepcopy(docs)], top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception, list_of_lists=True)


def check_results(results, test_input, warnings_list, caplog, exception, list_of_lists=False):
    expected_logs_count = 1
    if list_of_lists:
        expected_logs_count = 2

    if "expected_exception" in test_input and test_input["expected_exception"] is not None:
        assert exception.message == test_input["expected_exception"].message
        assert type(exception) == type(test_input["expected_exception"])
        return
    else:
        assert exception is None

    # Check that no warnings were thrown, if we are not expecting any
    if "expected_warning" not in test_input or test_input["expected_warning"] == []:
        assert len(warnings_list) == 0
    # Check that all expected warnings happened, and only those
    else:
        assert len(warnings_list) == len(test_input["expected_warning"])
        for i in range(len(warnings_list)):
            assert test_input["expected_warning"][int(i)] == str(warnings_list[i].message)

    # If we expect log records, compare them one by one
    if "expected_logs" not in test_input or test_input["expected_logs"] == []:
        assert len(caplog.record_tuples) == 0
    else:
        assert expected_logs_count * len(test_input["expected_logs"]) == len(caplog.record_tuples)
        for i in range(len(caplog.record_tuples)):
            assert test_input["expected_logs"][int(i / expected_logs_count)] == caplog.record_tuples[i]

    if not list_of_lists:
        check_result_content(results, test_input)
    else:
        for i in results:
            check_result_content(i, test_input)


# Verify that the order and the score of the documents match
def check_result_content(results, test_input):
    assert len(results) == len(test_input["expected_order"])
    for i in range(len(test_input["expected_order"])):
        assert test_input["expected_order"][i] == results[i].id
        if "expected_scores" in test_input:
            assert test_input["expected_scores"][results[i].id] == results[i].score
        else:
            assert results[i].score is None