feat: RecentnessRanker (#5301)

* recency reranker code

* removed

* readd

* edited code

* edit

* mypy test fix

* adding warnings for score method

* fix

* fix

* adding paper link

* comments implementation

* change to predict and predict_batch

* change to predict and predict_batch 2

* adding unit test

* fixes

* small fixes

* fix for unit test

* table driven test

* small fixes

* small fixes2

* adding predict_batch tests

* add recentness_ranker to api reference docs

* implementing feedback

* implementing feedback2

* implementing feedback3

* implementing feedback4

* implementing feedback5

* remove document_map, remove final check if score is not None

* add final check if doc score is not None for mypy

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
elundaeva 2023-07-20 16:20:45 +02:00, committed by GitHub
parent c2506866bd
commit 612c6779fb
3 changed files with 473 additions and 2 deletions


@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers"]
+    modules: ["base", "sentence_transformers", "recentness_ranker"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter


@@ -0,0 +1,188 @@
import logging
import warnings
from collections import defaultdict
from typing import List, Union, Optional, Dict, Literal

from dateutil.parser import parse, ParserError

from haystack.errors import NodeError
from haystack.nodes.ranker.base import BaseRanker
from haystack.schema import Document

logger = logging.getLogger(__name__)


class RecentnessRanker(BaseRanker):
    outgoing_edges = 1

    def __init__(
        self,
        date_meta_field: str,
        weight: float = 0.5,
        top_k: Optional[int] = None,
        ranking_mode: Literal["reciprocal_rank_fusion", "score"] = "reciprocal_rank_fusion",
    ):
        """
        This node reranks retrieved documents based on their age, so that newer documents rank higher.
        The importance of recentness is parametrized through the weight parameter.

        :param date_meta_field: Identifier pointing to the date field in the metadata.
            This is a required parameter, since we need dates for sorting.
        :param weight: A value in the range [0,1].
            0 disables sorting by age.
            0.5 gives content and age the same impact.
            1 sorts only by age, with the most recent documents coming first.
        :param top_k: (optional) How many documents to return. If not provided, all documents are returned.
            It can make sense to use large top-k values in the initial retrievers and to filter the documents
            down in the RecentnessRanker with this top_k parameter.
        :param ranking_mode: The mode used to combine the retriever score and the recentness score. Possible values
            are 'reciprocal_rank_fusion' (default) and 'score'. Use 'score' mode only with retrievers or rankers
            that return well-calibrated scores in the range [0,1].
        """
        super().__init__()
        self.date_meta_field = date_meta_field
        self.weight = weight
        self.top_k = top_k
        self.ranking_mode = ranking_mode

        if self.weight < 0 or self.weight > 1:
            raise NodeError(
                """
                Param <weight> needs to be in range [0,1] but was set to '{}'. \n
                Please change param <weight> when initializing the RecentnessRanker.
                """.format(
                    self.weight
                )
            )
    # pylint: disable=arguments-differ
    def predict(  # type: ignore
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """
        This method ranks a list of documents based on their age and relevance by:
        1. Adjusting the relevance score from the previous node (or, for RRF, calculating it from scratch and then adjusting it) based on the weight chosen at initialization.
        2. Sorting the documents by the age given in their metadata, calculating the recentness score, and adjusting it by the weight as well.
        3. Returning the top-k documents (or all documents, if top-k is not provided) sorted by the final score (relevance score + recentness score).

        :param query: Not used in practice (so it can be left blank), as this ranker does not sort documents by their semantic closeness to the query.
        :param documents: Documents provided for ranking.
        :param top_k: (optional) How many documents to return. If not provided, all documents are returned, sorted by relevance and recentness (adjusted by the weight).
        """
        try:
            sorted_by_date = sorted(documents, reverse=True, key=lambda x: parse(x.meta[self.date_meta_field]))
        except KeyError:
            raise NodeError(
                """
                Param <date_meta_field> was set to '{}', but document(s) {} do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """.format(
                    self.date_meta_field,
                    ",".join([doc.id for doc in documents if self.date_meta_field not in doc.meta]),
                )
            )
        except ParserError:
            logger.error(
                """
                Could not parse date information for dates: %s\n
                Continuing without sorting by date.
                """,
                " - ".join([doc.meta.get(self.date_meta_field, "identifier wrong") for doc in documents]),
            )
            return documents

        # Merge the scores for documents sorted both by content and by date.
        # If the ranking mode is 'reciprocal_rank_fusion', RRF is used to combine the previous ranking with the recency ranking.
        # If the ranking mode is 'score', documents are assigned a recentness score in [0,1] and are re-ranked
        # based on both their recentness score and their pre-existing relevance score.
        scores_map: Dict = defaultdict(int)
        if self.ranking_mode not in ["reciprocal_rank_fusion", "score"]:
            raise NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to '{}'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """.format(
                    self.ranking_mode
                )
            )

        for i, doc in enumerate(documents):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight)
            elif self.ranking_mode == "score":
                score = 0.0
                if doc.score is None:
                    warnings.warn("The score was not provided; defaulting to 0")
                elif doc.score < 0 or doc.score > 1:
                    warnings.warn(
                        "The score {} for document {} is outside the [0,1] range; defaulting to 0".format(
                            doc.score, doc.id
                        )
                    )
                else:
                    score = doc.score

                scores_map[doc.id] += score * (1 - self.weight)

        for i, doc in enumerate(sorted_by_date):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * self.weight
            elif self.ranking_mode == "score":
                scores_map[doc.id] += self._calc_recentness_score(rank=i, amount=len(sorted_by_date)) * self.weight

        top_k = top_k or self.top_k or len(documents)

        for doc in documents:
            doc.score = scores_map[doc.id]

        return sorted(documents, key=lambda doc: doc.score if doc.score is not None else -1, reverse=True)[:top_k]
    # pylint: disable=arguments-differ
    def predict_batch(  # type: ignore
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        This method ranks either A) a list of documents or B) a list of lists of documents (in case the previous
        node is JoinDocuments) based on their age and relevance.
        In case A, the predict method defined above is applied to the provided list.
        In case B, the predict method is applied to each individual list, and the results are returned as a list of lists.

        :param queries: Not used in practice (so it can be left blank), as this ranker does not sort documents by their semantic closeness to the query.
        :param documents: Documents provided for ranking, as a list or a list of lists.
        :param top_k: (optional) How many documents to return per list. If not provided, all documents are returned, sorted by relevance and recentness (adjusted by the weight).
        :param batch_size: Not used in practice, so it can be left blank.
        """
        if isinstance(documents[0], Document):
            return self.predict("", documents=documents, top_k=top_k)  # type: ignore

        nested_docs = []
        for docs in documents:
            results = self.predict("", documents=docs, top_k=top_k)  # type: ignore
            nested_docs.append(results)
        return nested_docs

    @staticmethod
    def _calculate_rrf(rank: int, k: int = 61) -> float:
        """
        Calculates the reciprocal rank fusion. The constant k is set to 61 (60 was suggested by the original paper
        [https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf], plus 1 because Python lists are 0-based
        while the paper uses 1-based ranking).
        """
        return 1 / (k + rank)

    @staticmethod
    def _calc_recentness_score(rank: int, amount: int) -> float:
        """
        Calculates the recentness score as a linear score between the most recent and the oldest document.
        This linear scaling is useful to
        a) reduce the effect of outliers and
        b) create recentness scores that are meaningfully distributed in the range [0,1],
        similar to scores coming from a retriever or ranker.
        """
        return (amount - rank) / amount
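
For reference, a minimal usage sketch of the new node (not part of the diff; the documents, ids, and scores below are made up for illustration):

from haystack.schema import Document
from haystack.nodes.ranker.recentness_ranker import RecentnessRanker

docs = [
    Document(content="old but relevant", meta={"date": "2020-02-11"}, score=0.6, id="a"),
    Document(content="new but less relevant", meta={"date": "2024-02-11"}, score=0.4, id="b"),
]

# In 'score' mode, the final score is (1 - weight) * relevance + weight * recentness,
# where the recentness of the document at rank i of n date-sorted documents is (n - i) / n.
ranker = RecentnessRanker(date_meta_field="date", ranking_mode="score", weight=0.5)
results = ranker.predict(query="", documents=docs)

# doc "b": 0.5 * 0.4 + 0.5 * (2 / 2) = 0.70
# doc "a": 0.5 * 0.6 + 0.5 * (1 / 2) = 0.55
assert [doc.id for doc in results] == ["b", "a"]

# In 'reciprocal_rank_fusion' mode, incoming scores are ignored and each ranking instead
# contributes 1 / (61 + rank). For lists of lists (for example, after JoinDocuments),
# use predict_batch(queries=[""], documents=[list_1, list_2]).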


@@ -1,12 +1,16 @@
 import pytest
 import math
+import warnings
 import logging
+import copy
 from unittest.mock import patch

 import torch

 from haystack.schema import Document
 from haystack.nodes.ranker.base import BaseRanker
 from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
-from haystack.errors import HaystackError
+from haystack.nodes.ranker.recentness_ranker import RecentnessRanker
+from haystack.errors import HaystackError, NodeError
@pytest.fixture
@@ -499,3 +503,282 @@ def test_cohere_ranker_batch_multiple_queries_multiple_doc_lists(docs, mock_cohe
     assert isinstance(results[0], list)
     assert results[0][0] == docs[4]
     assert results[1][0] == docs[4]

recency_tests_inputs = [
    # Score ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.4833333333333333, "2": 0.7},
            "expected_order": ["2", "1"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="Score ranking mode works as expected",
    ),
    # RRF ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2018-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
            "expected_scores": {"1": 0.01639344262295082, "2": 0.016001024065540194},
            "expected_order": ["1", "2"],
            "expected_logs": [],
            "expected_warning": "",
        },
        id="RRF ranking mode works as expected",
    ),
    # Wrong field to find the date
    pytest.param(
        {
            "docs": [
                {"meta": {"data": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_scores": {"1": 0.3, "2": 0.4, "3": 0.6},
            "expected_order": ["1", "2", "3"],
            "expected_exception": NodeError(
                """
                Param <date_meta_field> was set to 'date', but document(s) 1 do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """
            ),
            "top_k": 2,
            "ranking_mode": "score",
        },
        id="Wrong field to find the date",
    ),
    # Date unparsable
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "abcd"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_order": ["1", "2", "3"],
            "expected_logs": [
                (
                    "haystack.nodes.ranker.recentness_ranker",
                    logging.ERROR,
                    """
                Could not parse date information for dates: abcd - 2024-02-11 - 2020-02-11\n
                Continuing without sorting by date.
                """,
                )
            ],
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
        },
        id="Date unparsable",
    ),
    # Wrong score, outside of bounds
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 1.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score 1.3 for document 1 is outside the [0,1] range; defaulting to 0"],
        },
        id="Wrong score, outside of bounds",
    ),
    # Wrong score, not provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score was not provided; defaulting to 0"],
        },
        id="Wrong score, not provided",
    ),
    # Wrong ranking mode provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "blablabla",
            "expected_scores": {"1": 0.01626123744050767, "2": 0.01626123744050767},
            "expected_order": ["1", "2"],
            "expected_exception": NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to 'blablabla'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """
            ),
        },
        id="Wrong ranking mode provided",
    ),
]

@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        predict_exception = None
        results = []
        try:
            results = ranker.predict(query="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)

@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list as input
            results = ranker.predict_batch(queries="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)

@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list_of_lists(caplog, test_input):
    # Create a set of docs
    docs = []
    for doc in test_input["docs"]:
        docs.append(Document(content="abc", **doc))

    # Catch warnings to check they are properly issued
    with warnings.catch_warnings(record=True) as warnings_list:
        # Initialize the ranker
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        predict_exception = None
        results = []
        try:
            # Run predict_batch with a list of lists as input
            results = ranker.predict_batch(queries="", documents=[docs, copy.deepcopy(docs)], top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception, list_of_lists=True)

def check_results(results, test_input, warnings_list, caplog, exception, list_of_lists=False):
    expected_logs_count = 1
    if list_of_lists:
        expected_logs_count = 2

    if "expected_exception" in test_input and test_input["expected_exception"] is not None:
        assert exception.message == test_input["expected_exception"].message
        assert type(exception) == type(test_input["expected_exception"])
        return
    else:
        assert exception is None

    # Check that no warnings were thrown, if we are not expecting any
    if "expected_warning" not in test_input or test_input["expected_warning"] == []:
        assert len(warnings_list) == 0
    # Check that all expected warnings happened, and only those
    else:
        assert len(warnings_list) == len(test_input["expected_warning"])
        for i in range(len(warnings_list)):
            assert test_input["expected_warning"][i] == str(warnings_list[i].message)

    # If we expect log messages, compare them one by one
    if "expected_logs" not in test_input or test_input["expected_logs"] == []:
        assert len(caplog.record_tuples) == 0
    else:
        assert expected_logs_count * len(test_input["expected_logs"]) == len(caplog.record_tuples)
        for i in range(len(caplog.record_tuples)):
            assert test_input["expected_logs"][i // expected_logs_count] == caplog.record_tuples[i]

    if not list_of_lists:
        check_result_content(results, test_input)
    else:
        for i in results:
            check_result_content(i, test_input)

# Verify that the order and the scores of the returned documents match expectations
def check_result_content(results, test_input):
    assert len(results) == len(test_input["expected_order"])
    for i in range(len(test_input["expected_order"])):
        assert test_input["expected_order"][i] == results[i].id
        if "expected_scores" in test_input:
            assert test_input["expected_scores"][results[i].id] == results[i].score
        else:
            assert results[i].score is None
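
As a sanity check on the fixtures above, the expected scores can be reproduced by hand (weight = 0.5, RRF constant k = 61); a quick sketch, not part of the diff:

import math

# 'score' mode fixture: the date order is [2, 1, 3], so recentness is 3/3, 2/3, 1/3.
assert math.isclose(0.5 * 0.4 + 0.5 * (3 / 3), 0.7)                 # doc "2"
assert math.isclose(0.5 * 0.3 + 0.5 * (2 / 3), 0.4833333333333333)  # doc "1"

# RRF fixture: the content order is [1, 2, 3] and the date order is [1, 3, 2].
assert math.isclose(0.5 / 61 + 0.5 / 61, 0.01639344262295082)       # doc "1"
assert math.isclose(0.5 / 62 + 0.5 / 63, 0.016001024065540194)      # doc "2"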