diff --git a/docs/pydoc/config/ranker.yml b/docs/pydoc/config/ranker.yml
index 5c0917dd7..ff33cc3cf 100644
--- a/docs/pydoc/config/ranker.yml
+++ b/docs/pydoc/config/ranker.yml
@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers"]
+    modules: ["base", "sentence_transformers", "recentness_ranker"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
diff --git a/haystack/nodes/ranker/recentness_ranker.py b/haystack/nodes/ranker/recentness_ranker.py
new file mode 100644
index 000000000..a356f1301
--- /dev/null
+++ b/haystack/nodes/ranker/recentness_ranker.py
@@ -0,0 +1,188 @@
+import logging
+import warnings
+from collections import defaultdict
+from typing import List, Union, Optional, Dict, Literal
+
+from dateutil.parser import parse, ParserError
+from haystack.errors import NodeError
+
+from haystack.nodes.ranker.base import BaseRanker
+from haystack.schema import Document
+
+logger = logging.getLogger(__name__)
+
+
+class RecentnessRanker(BaseRanker):
+    outgoing_edges = 1
+
+    def __init__(
+        self,
+        date_meta_field: str,
+        weight: float = 0.5,
+        top_k: Optional[int] = None,
+        ranking_mode: Literal["reciprocal_rank_fusion", "score"] = "reciprocal_rank_fusion",
+    ):
+        """
+        This Node is used to rerank retrieved documents based on their age. Newer documents rank higher.
+        The importance of recentness is parametrized through the weight parameter.
+
+        :param date_meta_field: Identifier pointing to the date field in the metadata.
+            This is a required parameter, since we need dates for sorting.
+        :param weight: In range [0,1].
+            0 disables sorting by age.
+            0.5 gives content and age the same impact.
+            1 sorts only by age, with the most recent documents coming first.
+        :param top_k: (optional) How many documents to return. If not provided, all documents will be returned.
+            It can make sense to have large top_k values in the initial retrievers and filter documents down in the
+            RecentnessRanker with this top_k parameter.
+        :param ranking_mode: The mode used to combine retriever relevance and recentness. Possible values are
+            'reciprocal_rank_fusion' (default) and 'score'.
+            Make sure to use 'score' mode only with retrievers or rankers that return properly normalized scores in range [0,1].
+        """
+
+        super().__init__()
+        self.date_meta_field = date_meta_field
+        self.weight = weight
+        self.top_k = top_k
+        self.ranking_mode = ranking_mode
+
+        if self.weight < 0 or self.weight > 1:
+            raise NodeError(
+                """
+                Param weight needs to be in range [0,1] but was set to '{}'. \n
+                Please change the weight param when initializing the RecentnessRanker.
+                """.format(
+                    self.weight
+                )
+            )
+
+    # pylint: disable=arguments-differ
+    def predict(  # type: ignore
+        self, query: str, documents: List[Document], top_k: Optional[int] = None
+    ) -> List[Document]:
+        """
+        This method is used to rank a list of documents based on their age and relevance by:
+        1. Adjusting the relevance score from the previous node (or, for RRF, calculating it from scratch, then adjusting) based on the weight chosen at initialization.
+        2. Sorting the documents based on their age in the metadata, calculating the recentness score, and adjusting it by the weight as well.
+        3. Returning the top_k documents (or all documents, if top_k is not provided) sorted by the final score (relevance score + recentness score).
+
+        :param query: Not used in practice (so it can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
+        :param documents: Documents provided for ranking.
+        :param top_k: (optional) How many documents to return at the end. If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
+        """
+
+        try:
+            sorted_by_date = sorted(documents, reverse=True, key=lambda x: parse(x.meta[self.date_meta_field]))
+        except KeyError:
+            raise NodeError(
+                """
+                Param date_meta_field was set to '{}', but document(s) {} do not contain this metadata key.\n
+                Please double-check the names of existing metadata fields of your documents \n
+                and set date_meta_field to the name of the field that contains dates.
+                """.format(
+                    self.date_meta_field,
+                    ",".join([doc.id for doc in documents if self.date_meta_field not in doc.meta]),
+                )
+            )
+
+        except ParserError:
+            logger.error(
+                """
+                Could not parse date information for dates: %s\n
+                Continuing without sorting by date.
+                """,
+                " - ".join([doc.meta.get(self.date_meta_field, "identifier wrong") for doc in documents]),
+            )
+            return documents
+
+        # merge scores for documents sorted both by content and by date.
+        # If ranking mode is set to 'reciprocal_rank_fusion', then that is used to combine previous ranking with recency ranking.
+        # If ranking mode is set to 'score', then documents will be assigned a recency score in [0,1] and will be
+        # re-ranked based on both their recency score and their pre-existing relevance score.
+        scores_map: Dict = defaultdict(int)
+        if self.ranking_mode not in ["reciprocal_rank_fusion", "score"]:
+            raise NodeError(
+                """
+                Param ranking_mode needs to be 'reciprocal_rank_fusion' or 'score' but was set to '{}'. \n
+                Please change the ranking_mode when initializing the RecentnessRanker.
+                """.format(
+                    self.ranking_mode
+                )
+            )
+
+        for i, doc in enumerate(documents):
+            if self.ranking_mode == "reciprocal_rank_fusion":
+                scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight)
+            elif self.ranking_mode == "score":
+                score = float(0)
+                if doc.score is None:
+                    warnings.warn("The score was not provided; defaulting to 0")
+                elif doc.score < 0 or doc.score > 1:
+                    warnings.warn(
+                        "The score {} for document {} is outside the [0,1] range; defaulting to 0".format(
+                            doc.score, doc.id
+                        )
+                    )
+                else:
+                    score = doc.score
+
+                scores_map[doc.id] += score * (1 - self.weight)
+
+        for i, doc in enumerate(sorted_by_date):
+            if self.ranking_mode == "reciprocal_rank_fusion":
+                scores_map[doc.id] += self._calculate_rrf(rank=i) * self.weight
+            elif self.ranking_mode == "score":
+                scores_map[doc.id] += self._calc_recentness_score(rank=i, amount=len(sorted_by_date)) * self.weight
+
+        top_k = top_k or self.top_k or len(documents)
+
+        for doc in documents:
+            doc.score = scores_map[doc.id]
+
+        return sorted(documents, key=lambda doc: doc.score if doc.score is not None else -1, reverse=True)[:top_k]
+
+    # pylint: disable=arguments-differ
+    def predict_batch(  # type: ignore
+        self,
+        queries: List[str],
+        documents: Union[List[Document], List[List[Document]]],
+        top_k: Optional[int] = None,
+        batch_size: Optional[int] = None,
+    ) -> Union[List[Document], List[List[Document]]]:
+        """
+        This method is used to rank A) a list or B) a list of lists (in case the previous node is JoinDocuments) of documents based on their age and relevance.
+        In case A, the predict method defined earlier is applied to the provided list.
+        In case B, the predict method is applied to each individual list in the list of lists provided, and the results are returned as a list of lists.
+
+        :param queries: Not used in practice (so it can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
+        :param documents: Documents provided for ranking in a list or a list of lists.
+        :param top_k: (optional) How many documents to return at the end (per list). If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
+        :param batch_size: Not used in practice, so it can be left blank.
+        """
+
+        if isinstance(documents[0], Document):
+            return self.predict("", documents=documents, top_k=top_k)  # type: ignore
+        nested_docs = []
+        for docs in documents:
+            results = self.predict("", documents=docs, top_k=top_k)  # type: ignore
+            nested_docs.append(results)
+
+        return nested_docs
+
+    @staticmethod
+    def _calculate_rrf(rank: int, k: int = 61) -> float:
+        """
+        Calculates the reciprocal rank fusion. The constant k is set to 61 (60 was suggested by the original paper,
+        plus 1 as Python lists are 0-based and the paper [https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf]
+        used 1-based ranking).
+        """
+        return 1 / (k + rank)
+
+    @staticmethod
+    def _calc_recentness_score(rank: int, amount: int) -> float:
+        """
+        Calculates the recentness score as a linear score between the most recent and the oldest document.
+        This linear scaling is useful to
+        a) reduce the effect of outliers and
+        b) create recentness scores that are meaningfully distributed in [0,1],
+        similar to scores coming from a retriever/ranker.
+        """
+        return (amount - rank) / amount
diff --git a/test/nodes/test_ranker.py b/test/nodes/test_ranker.py
index c5faef046..baedd0086 100644
--- a/test/nodes/test_ranker.py
+++ b/test/nodes/test_ranker.py
@@ -1,12 +1,16 @@
 import pytest
 import math
+import warnings
+import logging
+import copy
 from unittest.mock import patch
 import torch
 
 from haystack.schema import Document
 from haystack.nodes.ranker.base import BaseRanker
 from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker
-from haystack.errors import HaystackError
+from haystack.nodes.ranker.recentness_ranker import RecentnessRanker
+from haystack.errors import HaystackError, NodeError
 
 
 @pytest.fixture
@@ -499,3 +503,282 @@ def test_cohere_ranker_batch_multiple_queries_multiple_doc_lists(docs, mock_cohe
     assert isinstance(results[0], list)
     assert results[0][0] == docs[4]
     assert results[1][0] == docs[4]
+
+
+recency_tests_inputs = [
+    # Score ranking mode works as expected
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "2021-02-11"}, "score": 0.3, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "top_k": 2,
+            "ranking_mode": "score",
+            "expected_scores": {"1": 0.4833333333333333, "2": 0.7},
+            "expected_order": ["2", "1"],
+            "expected_logs": [],
+            "expected_warning": "",
+        },
+        id="Score ranking mode works as expected",
+    ),
+    # RRF ranking mode works as expected
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "2021-02-11"}, "id": "1"},
+                {"meta": {"date": "2018-02-11"}, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "top_k": 2,
+            "ranking_mode": "reciprocal_rank_fusion",
+            "expected_scores": {"1": 0.01639344262295082, "2": 0.016001024065540194},
+            "expected_order": ["1", "2"],
+            "expected_logs": [],
+            "expected_warning": "",
+        },
+        id="RRF ranking mode works as expected",
+    ),
+    # Wrong field to find the date
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"data": "2021-02-11"}, "score": 0.3, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "expected_scores": {"1": 0.3, "2": 0.4, "3": 0.6},
+            "expected_order": ["1", "2", "3"],
+            "expected_exception": NodeError(
+                """
+                Param date_meta_field was set to 'date', but document(s) 1 do not contain this metadata key.\n
+                Please double-check the names of existing metadata fields of your documents \n
+                and set date_meta_field to the name of the field that contains dates.
+                """
+            ),
+            "top_k": 2,
+            "ranking_mode": "score",
+        },
+        id="Wrong field to find the date",
+    ),
+    # Date unparsable
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "abcd"}, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "expected_order": ["1", "2", "3"],
+            "expected_logs": [
+                (
+                    "haystack.nodes.ranker.recentness_ranker",
+                    logging.ERROR,
+                    """
+                Could not parse date information for dates: abcd - 2024-02-11 - 2020-02-11\n
+                Continuing without sorting by date.
+                """,
+                )
+            ],
+            "top_k": 2,
+            "ranking_mode": "reciprocal_rank_fusion",
+        },
+        id="Date unparsable",
+    ),
+    # Wrong score, outside of bounds
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "2021-02-11"}, "score": 1.3, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "top_k": 2,
+            "ranking_mode": "score",
+            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
+            "expected_order": ["2", "3"],
+            "expected_warning": ["The score 1.3 for document 1 is outside the [0,1] range; defaulting to 0"],
+        },
+        id="Wrong score, outside of bounds",
+    ),
+    # Wrong score, not provided
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "2021-02-11"}, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "top_k": 2,
+            "ranking_mode": "score",
+            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
+            "expected_order": ["2", "3"],
+            "expected_warning": ["The score was not provided; defaulting to 0"],
+        },
+        id="Wrong score, not provided",
+    ),
+    # Wrong ranking mode provided
+    pytest.param(
+        {
+            "docs": [
+                {"meta": {"date": "2021-02-11"}, "id": "1"},
+                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
+                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
+            ],
+            "weight": 0.5,
+            "date_meta_field": "date",
+            "top_k": 2,
+            "ranking_mode": "blablabla",
+            "expected_scores": {"1": 0.01626123744050767, "2": 0.01626123744050767},
+            "expected_order": ["1", "2"],
+            "expected_exception": NodeError(
+                """
+                Param ranking_mode needs to be 'reciprocal_rank_fusion' or 'score' but was set to 'blablabla'. \n
+                Please change the ranking_mode when initializing the RecentnessRanker.
+                """
+            ),
+        },
+        id="Wrong ranking mode provided",
+    ),
+]
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("test_input", recency_tests_inputs)
+def test_recentness_ranker(caplog, test_input):
+    # Create a set of docs
+    docs = []
+    for doc in test_input["docs"]:
+        docs.append(Document(content="abc", **doc))
+
+    # catch warnings to check they are properly issued
+    with warnings.catch_warnings(record=True) as warnings_list:
+        # Initialize the ranker
+        ranker = RecentnessRanker(
+            date_meta_field=test_input["date_meta_field"],
+            ranking_mode=test_input["ranking_mode"],
+            weight=test_input["weight"],
+        )
+        predict_exception = None
+        results = []
+        try:
+            results = ranker.predict(query="", documents=docs, top_k=test_input["top_k"])
+        except Exception as e:
+            predict_exception = e
+
+        check_results(results, test_input, warnings_list, caplog, predict_exception)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("test_input", recency_tests_inputs)
+def test_recentness_ranker_batch_list(caplog, test_input):
+    # Create a set of docs
+    docs = []
+    for doc in test_input["docs"]:
+        docs.append(Document(content="abc", **doc))
+
+    # catch warnings to check they are properly issued
+    with warnings.catch_warnings(record=True) as warnings_list:
+        # Initialize the ranker
+        ranker = RecentnessRanker(
+            date_meta_field=test_input["date_meta_field"],
+            ranking_mode=test_input["ranking_mode"],
+            weight=test_input["weight"],
+        )
+        predict_exception = None
+        results = []
+        try:
+            # Run predict_batch with a list as input
+            results = ranker.predict_batch(queries="", documents=docs, top_k=test_input["top_k"])
+        except Exception as e:
+            predict_exception = e
+
+        check_results(results, test_input, warnings_list, caplog, predict_exception)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("test_input", recency_tests_inputs)
+def test_recentness_ranker_batch_list_of_lists(caplog, test_input):
+    # Create a set of docs
+    docs = []
+    for doc in test_input["docs"]:
+        docs.append(Document(content="abc", **doc))
+
+    # catch warnings to check they are properly issued
+    with warnings.catch_warnings(record=True) as warnings_list:
+        # Initialize the ranker
+        ranker = RecentnessRanker(
+            date_meta_field=test_input["date_meta_field"],
+            ranking_mode=test_input["ranking_mode"],
+            weight=test_input["weight"],
+        )
+
+        predict_exception = None
+        results = []
+        try:
+            # Run predict_batch with a list of lists as input
+            results = ranker.predict_batch(queries="", documents=[docs, copy.deepcopy(docs)], top_k=test_input["top_k"])
+        except Exception as e:
+            predict_exception = e
+
+        check_results(results, test_input, warnings_list, caplog, predict_exception, list_of_lists=True)
+
+
+def check_results(results, test_input, warnings_list, caplog, exception, list_of_lists=False):
+    expected_logs_count = 1
+    if list_of_lists:
+        expected_logs_count = 2
+
+    if "expected_exception" in test_input and test_input["expected_exception"] is not None:
+        assert exception.message == test_input["expected_exception"].message
+        assert type(exception) == type(test_input["expected_exception"])
+        return
+    else:
+        assert exception is None
+
+    # Check that no warnings were thrown, if we are not expecting any
+    if "expected_warning" not in test_input or test_input["expected_warning"] == []:
+        assert len(warnings_list) == 0
+    # Check that all expected warnings happened, and only those
+    else:
+        assert len(warnings_list) == len(test_input["expected_warning"])
+        for i in range(len(warnings_list)):
+            assert test_input["expected_warning"][int(i)] == str(warnings_list[i].message)
+
+    # If we expect logging, compare the records one by one
+    if "expected_logs" not in test_input or test_input["expected_logs"] == []:
+        assert len(caplog.record_tuples) == 0
+    else:
+        assert expected_logs_count * len(test_input["expected_logs"]) == len(caplog.record_tuples)
+        for i in range(len(caplog.record_tuples)):
+            assert test_input["expected_logs"][int(i / expected_logs_count)] == caplog.record_tuples[i]
+
+    if not list_of_lists:
+        check_result_content(results, test_input)
+    else:
+        for i in results:
+            check_result_content(i, test_input)
+
+
+# Verify the results, that the order and the score of the documents match
+def check_result_content(results, test_input):
+    assert len(results) == len(test_input["expected_order"])
+    for i in range(len(test_input["expected_order"])):
+        assert test_input["expected_order"][i] == results[i].id
+        if "expected_scores" in test_input:
+            assert test_input["expected_scores"][results[i].id] == results[i].score
+        else:
+            assert results[i].score is None
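
A minimal usage sketch of the new node (illustrative only, not part of the diff). It relies solely on the public API introduced above (date_meta_field, ranking_mode, weight, top_k, predict) and reuses the document values from the first test case; the content strings are made up for illustration.

# Sketch: rerank documents by recency in 'score' mode.
from haystack.schema import Document
from haystack.nodes.ranker.recentness_ranker import RecentnessRanker

docs = [
    Document(content="older article", meta={"date": "2021-02-11"}, score=0.3, id="1"),
    Document(content="recent article", meta={"date": "2024-02-11"}, score=0.4, id="2"),
    Document(content="old article", meta={"date": "2020-02-11"}, score=0.6, id="3"),
]

# weight=0.5 blends the incoming relevance score and the recentness score equally.
ranker = RecentnessRanker(date_meta_field="date", ranking_mode="score", weight=0.5)
reranked = ranker.predict(query="", documents=docs, top_k=2)

for doc in reranked:
    print(doc.id, doc.score)  # document "2" comes first: it is both recent and relevant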
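The expected values in recency_tests_inputs follow directly from the two combination rules in the new module. The snippet below re-derives them by hand as a sanity check; it is a sketch for verification, not additional test code.

# 'score' mode, weight=0.5: final = score * (1 - weight) + recentness * weight,
# where the recentness of the document at date-rank r out of n documents is (n - r) / n.
doc2 = 0.4 * 0.5 + (3 - 0) / 3 * 0.5  # newest document, date-rank 0 -> 0.7
doc1 = 0.3 * 0.5 + (3 - 1) / 3 * 0.5  # date-rank 1 -> 0.4833...
assert abs(doc2 - 0.7) < 1e-12
assert abs(doc1 - 0.4833333333333333) < 1e-12

# 'reciprocal_rank_fusion' mode, weight=0.5, k=61, 0-based ranks:
# final = (1 - weight) / (k + content_rank) + weight / (k + date_rank)
doc1_rrf = 0.5 * (1 / 61) + 0.5 * (1 / 61)  # rank 0 in both orderings -> 1/61
doc2_rrf = 0.5 * (1 / 62) + 0.5 * (1 / 63)  # content rank 1, date rank 2
assert abs(doc1_rrf - 0.01639344262295082) < 1e-15
assert abs(doc2_rrf - 0.016001024065540194) < 1e-15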
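When the ranker sits after JoinDocuments, the previous node can emit a list of document lists; predict_batch handles both shapes, as its docstring notes. A small self-contained sketch (the document ids and contents are invented for illustration):

# Sketch: predict_batch accepts a flat list or a list of lists (e.g. after JoinDocuments).
from haystack.schema import Document
from haystack.nodes.ranker.recentness_ranker import RecentnessRanker

batch = [
    [Document(content="a", meta={"date": "2024-02-11"}, id="new"), Document(content="b", meta={"date": "2020-02-11"}, id="old")],
    [Document(content="c", meta={"date": "2018-02-11"}, id="older"), Document(content="d", meta={"date": "2023-02-11"}, id="newer")],
]

# With weight=1.0 the content ranking is ignored and documents are ordered purely by recency.
ranker = RecentnessRanker(date_meta_field="date", ranking_mode="reciprocal_rank_fusion", weight=1.0)
nested = ranker.predict_batch(queries=[""], documents=batch)  # queries are unused by this ranker

# Each inner list is reranked independently, so the newest document leads each list.
assert [inner[0].id for inner in nested] == ["new", "newer"]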