mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-25 14:59:01 +00:00 
			
		
		
		
	feat: RecentnessRanker (#5301)
* recency reranker code * removed * readd * edited code * edit * mypy test fix * adding warnings for score method * fix * fix * adding paper link * comments implementation * change to predict and predict_batch * change to predict and predict_batch 2 * adding unit test * fixes * small fixes * fix for unit test * table driven test * small fixes * small fixes2 * adding predict_batch tests * add recentness_ranker to api reference docs * implementing feedback * implementing feedback2 * implementing feedback3 * implementing feedback4 * implementing feedback5 * remove document_map, remove final check if score is not None * add final check if doc score is not None for mypy --------- Co-authored-by: Darja Fokina <daria.f93@gmail.com> Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
		
							parent
							
								
									c2506866bd
								
							
						
					
					
						commit
						612c6779fb
					
				| @ -1,7 +1,7 @@ | |||||||
| loaders: | loaders: | ||||||
|   - type: python |   - type: python | ||||||
|     search_path: [../../../haystack/nodes/ranker] |     search_path: [../../../haystack/nodes/ranker] | ||||||
|     modules: ["base", "sentence_transformers"] |     modules: ["base", "sentence_transformers", "recentness_ranker"] | ||||||
|     ignore_when_discovered: ["__init__"] |     ignore_when_discovered: ["__init__"] | ||||||
| processors: | processors: | ||||||
|   - type: filter |   - type: filter | ||||||
|  | |||||||
							
								
								
									
										188
									
								
								haystack/nodes/ranker/recentness_ranker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								haystack/nodes/ranker/recentness_ranker.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,188 @@ | |||||||
|  | import logging | ||||||
|  | import warnings | ||||||
|  | from collections import defaultdict | ||||||
|  | from typing import List, Union, Optional, Dict, Literal | ||||||
|  | 
 | ||||||
|  | from dateutil.parser import parse, ParserError | ||||||
|  | from haystack.errors import NodeError | ||||||
|  | 
 | ||||||
|  | from haystack.nodes.ranker.base import BaseRanker | ||||||
|  | from haystack.schema import Document | ||||||
|  | 
 | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class RecentnessRanker(BaseRanker):
    """
    Ranker node that re-orders documents by combining their existing ranking/score with how recent they are.

    The date of each document is read from the metadata field named by `date_meta_field` and parsed with
    `dateutil.parser.parse`. The two signals are merged either with reciprocal rank fusion or with a weighted
    sum of scores, depending on `ranking_mode`.
    """

    outgoing_edges = 1

    def __init__(
        self,
        date_meta_field: str,
        weight: float = 0.5,
        top_k: Optional[int] = None,
        ranking_mode: Literal["reciprocal_rank_fusion", "score"] = "reciprocal_rank_fusion",
    ):
        """
        This Node is used to rerank retrieved documents based on their age. Newer documents will rank higher.
        The importance of recentness is parametrized through the weight parameter.

        :param date_meta_field: Identifier pointing to the date field in the metadata.
                This is a required parameter, since we need dates for sorting.
        :param weight: in range [0,1].
                0 disables sorting by age.
                0.5 content and age have the same impact.
                1 means sorting only by age, most recent comes first.
        :param top_k: (optional) How many documents to return. If not provided, all documents will be returned.
                It can make sense to have large top-k values from the initial retrievers and filter docs down in the
                RecentnessRanker with this top_k parameter.
        :param ranking_mode: The mode used to combine retriever and recentness. Possible values are 'reciprocal_rank_fusion' (default) and 'score'.
                Make sure to use 'score' mode only with retrievers/rankers that return scores in range [0,1].
                The value is validated when `predict` runs, not here.
        :raises NodeError: If `weight` is outside the range [0,1].
        """

        super().__init__()
        self.date_meta_field = date_meta_field
        self.weight = weight
        self.top_k = top_k
        self.ranking_mode = ranking_mode

        if self.weight < 0 or self.weight > 1:
            # Any value in [0,1] is accepted, so the message states the range rather than three sample values.
            raise NodeError(
                """
                Param <weight> needs to be in range [0,1] but was set to '{}'. \n
                Please change param <weight> when initializing the RecentnessRanker.
                """.format(
                    self.weight
                )
            )

    # pylint: disable=arguments-differ
    def predict(  # type: ignore
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """
        This method is used to rank a list of documents based on their age and relevance by:
        1. Adjusting the relevance score from the previous node (or, for RRF, calculating it from scratch, then adjusting) based on the chosen weight in initial parameters.
        2. Sorting the documents based on their age in the metadata, calculating the recentness score, adjusting it by weight as well.
        3. Returning top-k documents (or all, if top-k not provided) sorted by final score (relevance score + recentness score).

        The `score` attribute of every passed document is overwritten in place with the combined score.
        If a date cannot be parsed, an error is logged and the input list is returned unchanged
        (no re-scoring and no top-k truncation).

        :param query: Not used in practice (so can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
        :param documents: Documents provided for ranking.
        :param top_k: (optional) How many documents to return at the end. If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
        :raises NodeError: If a document lacks the `date_meta_field` metadata key or if `ranking_mode` is not supported.
        """

        try:
            sorted_by_date = sorted(documents, reverse=True, key=lambda x: parse(x.meta[self.date_meta_field]))
        except KeyError as err:
            raise NodeError(
                """
                Param <date_meta_field> was set to '{}', but document(s) {} do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """.format(
                    self.date_meta_field,
                    ",".join([doc.id for doc in documents if self.date_meta_field not in doc.meta]),
                )
            ) from err
        except ParserError:
            logger.error(
                """
                Could not parse date information for dates: %s\n
                Continuing without sorting by date.
                """,
                # str() keeps the join from failing if some meta values are not strings.
                " - ".join([str(doc.meta.get(self.date_meta_field, "identifier wrong")) for doc in documents]),
            )

            return documents

        # merge scores for documents sorted both by content and by date.
        # If ranking mode is set to 'reciprocal_rank_fusion', then that is used to combine previous ranking with recency ranking.
        # If ranking mode is set to 'score', then documents will be assigned a recency score in [0,1] and will be re-ranked based on both their recency score and their pre-existing relevance score.
        scores_map: Dict[str, float] = defaultdict(float)
        if self.ranking_mode not in ["reciprocal_rank_fusion", "score"]:
            raise NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to '{}'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """.format(
                    self.ranking_mode
                )
            )

        # Relevance part: the incoming order (RRF) or the incoming score, weighted by (1 - weight).
        for i, doc in enumerate(documents):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * (1 - self.weight)
            elif self.ranking_mode == "score":
                score = float(0)
                if doc.score is None:
                    warnings.warn("The score was not provided; defaulting to 0")
                elif doc.score < 0 or doc.score > 1:
                    warnings.warn(
                        "The score {} for document {} is outside the [0,1] range; defaulting to 0".format(
                            doc.score, doc.id
                        )
                    )
                else:
                    score = doc.score

                scores_map[doc.id] += score * (1 - self.weight)

        # Recentness part: the position in the newest-first ordering, weighted by weight.
        for i, doc in enumerate(sorted_by_date):
            if self.ranking_mode == "reciprocal_rank_fusion":
                scores_map[doc.id] += self._calculate_rrf(rank=i) * self.weight
            elif self.ranking_mode == "score":
                scores_map[doc.id] += self._calc_recentness_score(rank=i, amount=len(sorted_by_date)) * self.weight

        # A falsy top_k (None or 0) falls back to the node-level top_k and finally to "all documents".
        top_k = top_k or self.top_k or len(documents)

        for doc in documents:
            doc.score = scores_map[doc.id]

        return sorted(documents, key=lambda doc: doc.score if doc.score is not None else -1, reverse=True)[:top_k]

    # pylint: disable=arguments-differ
    def predict_batch(  # type: ignore
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        This method is used to rank A) a list or B) a list of lists (in case the previous node is JoinDocuments) of documents based on their age and relevance.
        In case A, the predict method defined earlier is applied to the provided list.
        In case B, predict method is applied to each individual list in the list of lists provided, then the results are returned as list of lists.
        An empty `documents` list yields an empty list.

        :param queries: Not used in practice (so can be left blank), as this ranker does not perform sorting based on semantic closeness of documents to the query.
        :param documents: Documents provided for ranking in a list or a list of lists.
        :param top_k: (optional) How many documents to return at the end (per list). If not provided, all documents will be returned, sorted by relevance and recentness (adjusted by weight).
        :param batch_size:  Not used in practice, so can be left blank.
        """

        # Nothing to rank; also avoids an IndexError on documents[0] below.
        if not documents:
            return []

        if isinstance(documents[0], Document):
            return self.predict("", documents=documents, top_k=top_k)  # type: ignore

        return [self.predict("", documents=docs, top_k=top_k) for docs in documents]  # type: ignore

    @staticmethod
    def _calculate_rrf(rank: int, k: int = 61) -> float:
        """
        Calculates the reciprocal rank fusion. The constant K is set to 61 (60 was suggested by the original paper,
        plus 1 as python lists are 0-based and the paper [https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf] used 1-based ranking).
        """
        return 1 / (k + rank)

    @staticmethod
    def _calc_recentness_score(rank: int, amount: int) -> float:
        """
        Calculate recentness score as a linear score between most recent and oldest document.
        This linear scaling is useful to
          a) reduce the effect of outliers and
          b) create recentness scores that are meaningfully distributed in [0,1],
             similar to scores coming from a retriever/ranker.

        :param rank: 0-based position of the document in the newest-first ordering.
        :param amount: Total number of ranked documents; must be greater than 0.
        """
        return (amount - rank) / amount
| @ -1,12 +1,16 @@ | |||||||
| import pytest | import pytest | ||||||
| import math | import math | ||||||
|  | import warnings | ||||||
|  | import logging | ||||||
|  | import copy | ||||||
| from unittest.mock import patch | from unittest.mock import patch | ||||||
| 
 | 
 | ||||||
| import torch | import torch | ||||||
| from haystack.schema import Document | from haystack.schema import Document | ||||||
| from haystack.nodes.ranker.base import BaseRanker | from haystack.nodes.ranker.base import BaseRanker | ||||||
| from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker | from haystack.nodes.ranker import SentenceTransformersRanker, CohereRanker | ||||||
| from haystack.errors import HaystackError | from haystack.nodes.ranker.recentness_ranker import RecentnessRanker | ||||||
|  | from haystack.errors import HaystackError, NodeError | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @ -499,3 +503,282 @@ def test_cohere_ranker_batch_multiple_queries_multiple_doc_lists(docs, mock_cohe | |||||||
|     assert isinstance(results[0], list) |     assert isinstance(results[0], list) | ||||||
|     assert results[0][0] == docs[4] |     assert results[0][0] == docs[4] | ||||||
|     assert results[1][0] == docs[4] |     assert results[1][0] == docs[4] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
# Table-driven cases shared by the RecentnessRanker tests below.
# Each case is a dict with the documents to rank, the ranker parameters and the expected outcome
# (order, scores, warnings, log records or exception) that check_results verifies.
recency_tests_inputs = [
    # Score ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.4833333333333333, "2": 0.7},
            "expected_order": ["2", "1"],
            "expected_logs": [],
            # Same type as in the other cases (a list of messages), not an empty string.
            "expected_warning": [],
        },
        id="Score ranking mode works as expected",
    ),
    # RRF ranking mode works as expected
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2018-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
            "expected_scores": {"1": 0.01639344262295082, "2": 0.016001024065540194},
            "expected_order": ["1", "2"],
            "expected_logs": [],
            "expected_warning": [],
        },
        id="RRF ranking mode works as expected",
    ),
    # Wrong field to find the date
    pytest.param(
        {
            "docs": [
                {"meta": {"data": "2021-02-11"}, "score": 0.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_scores": {"1": 0.3, "2": 0.4, "3": 0.6},
            "expected_order": ["1", "2", "3"],
            "expected_exception": NodeError(
                """
                Param <date_meta_field> was set to 'date', but document(s) 1 do not contain this metadata key.\n
                Please double-check the names of existing metadata fields of your documents \n
                and set <date_meta_field> to the name of the field that contains dates.
                """
            ),
            "top_k": 2,
            "ranking_mode": "score",
        },
        id="Wrong field to find the date",
    ),
    # Date unparsable
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "abcd"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "expected_order": ["1", "2", "3"],
            "expected_logs": [
                (
                    "haystack.nodes.ranker.recentness_ranker",
                    logging.ERROR,
                    """
                Could not parse date information for dates: abcd - 2024-02-11 - 2020-02-11\n
                Continuing without sorting by date.
                """,
                )
            ],
            "top_k": 2,
            "ranking_mode": "reciprocal_rank_fusion",
        },
        id="Date unparsable",
    ),
    # Wrong score, outside of bounds
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "score": 1.3, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score 1.3 for document 1 is outside the [0,1] range; defaulting to 0"],
        },
        id="Wrong score, outside of bonds",
    ),
    # Wrong score, not provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "score",
            "expected_scores": {"1": 0.5, "2": 0.7, "3": 0.4666666666666667},
            "expected_order": ["2", "3"],
            "expected_warning": ["The score was not provided; defaulting to 0"],
        },
        id="Wrong score, not provided",
    ),
    # Wrong ranking mode provided
    pytest.param(
        {
            "docs": [
                {"meta": {"date": "2021-02-11"}, "id": "1"},
                {"meta": {"date": "2024-02-11"}, "score": 0.4, "id": "2"},
                {"meta": {"date": "2020-02-11"}, "score": 0.6, "id": "3"},
            ],
            "weight": 0.5,
            "date_meta_field": "date",
            "top_k": 2,
            "ranking_mode": "blablabla",
            "expected_scores": {"1": 0.01626123744050767, "2": 0.01626123744050767},
            "expected_order": ["1", "2"],
            "expected_exception": NodeError(
                """
                Param <ranking_mode> needs to be 'reciprocal_rank_fusion' or 'score' but was set to 'blablabla'. \n
                Please change the <ranking_mode> when initializing the RecentnessRanker.
                """
            ),
        },
        id="Wrong ranking mode provided",
    ),
]
|  | 
 | ||||||
|  | 
 | ||||||
@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker(caplog, test_input):
    """Run `predict` on one table-driven case and verify its outcome with check_results."""
    # Build the documents described by the test case
    docs = [Document(content="abc", **doc_spec) for doc_spec in test_input["docs"]]

    # Record warnings so that check_results can compare them with the expected ones
    with warnings.catch_warnings(record=True) as warnings_list:
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        # Capture any exception instead of failing, since some cases expect one
        results, predict_exception = [], None
        try:
            results = ranker.predict(query="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)
|  | 
 | ||||||
|  | 
 | ||||||
@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list(caplog, test_input):
    """Run `predict_batch` with a flat list of documents and verify its outcome with check_results."""
    # Build the documents described by the test case
    docs = [Document(content="abc", **doc_spec) for doc_spec in test_input["docs"]]

    # Record warnings so that check_results can compare them with the expected ones
    with warnings.catch_warnings(record=True) as warnings_list:
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        # Capture any exception instead of failing, since some cases expect one
        results, predict_exception = [], None
        try:
            # predict_batch receives a single list of documents
            results = ranker.predict_batch(queries="", documents=docs, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception)
|  | 
 | ||||||
|  | 
 | ||||||
@pytest.mark.unit
@pytest.mark.parametrize("test_input", recency_tests_inputs)
def test_recentness_ranker_batch_list_of_lists(caplog, test_input):
    """Run `predict_batch` with two document lists and verify each result list with check_results."""
    # Build the documents described by the test case
    docs = [Document(content="abc", **doc_spec) for doc_spec in test_input["docs"]]

    # Record warnings so that check_results can compare them with the expected ones
    with warnings.catch_warnings(record=True) as warnings_list:
        ranker = RecentnessRanker(
            date_meta_field=test_input["date_meta_field"],
            ranking_mode=test_input["ranking_mode"],
            weight=test_input["weight"],
        )

        # Capture any exception instead of failing, since some cases expect one
        results, predict_exception = [], None
        try:
            # The second list is a deep copy, so the two lists do not share Document objects
            batch = [docs, copy.deepcopy(docs)]
            results = ranker.predict_batch(queries="", documents=batch, top_k=test_input["top_k"])
        except Exception as e:
            predict_exception = e

        check_results(results, test_input, warnings_list, caplog, predict_exception, list_of_lists=True)
|  | 
 | ||||||
|  | 
 | ||||||
def check_results(results, test_input, warnings_list, caplog, exception, list_of_lists=False):
    """
    Compare the outcome of a ranker call with the expectations stored in a test case.

    :param results: Ranked documents, or a list of ranked document lists if `list_of_lists` is True.
    :param test_input: Test case dict; the keys read here are "expected_exception", "expected_warning"
        and "expected_logs" (all optional). It is also passed on to check_result_content.
    :param warnings_list: Warnings recorded while the ranker ran.
    :param caplog: Log capture object; only its `record_tuples` attribute is read.
    :param exception: Exception raised by the ranker call, or None if it completed.
    :param list_of_lists: Whether the ranker was called with a list of document lists.
    """
    # The list-of-lists test passes two document lists, so every expected log record must appear twice.
    expected_logs_count = 2 if list_of_lists else 1

    expected_exception = test_input.get("expected_exception")
    if expected_exception is not None:
        # Check presence and type first, so a missing exception fails as an assertion
        # instead of an AttributeError on `None.message`.
        assert exception is not None, "expected an exception, but none was raised"
        assert type(exception) is type(expected_exception)
        assert exception.message == expected_exception.message
        return
    assert exception is None

    # A missing or empty entry means that no warning may have been issued at all.
    expected_warnings = test_input.get("expected_warning") or []
    assert len(warnings_list) == len(expected_warnings)
    for expected_message, recorded in zip(expected_warnings, warnings_list):
        assert expected_message == str(recorded.message)

    # Same for log records: a missing or empty entry means that nothing may have been logged.
    expected_logs = test_input.get("expected_logs") or []
    assert expected_logs_count * len(expected_logs) == len(caplog.record_tuples)
    for i, record in enumerate(caplog.record_tuples):
        assert expected_logs[i // expected_logs_count] == record

    if not list_of_lists:
        check_result_content(results, test_input)
    else:
        for result_list in results:
            check_result_content(result_list, test_input)
|  | 
 | ||||||
|  | 
 | ||||||
# Verify that the returned documents match the expected order and, if given, the expected scores
def check_result_content(results, test_input):
    """
    Assert that `results` contains exactly the documents listed in test_input["expected_order"], in that order.

    If the test case has an "expected_scores" mapping (document id -> score), each document's score must equal
    its entry; otherwise each document's score must be None.
    """
    expected_ids = test_input["expected_order"]
    assert len(results) == len(expected_ids)

    for expected_id, doc in zip(expected_ids, results):
        assert expected_id == doc.id
        if "expected_scores" not in test_input:
            assert doc.score is None
        else:
            assert test_input["expected_scores"][doc.id] == doc.score
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 elundaeva
						elundaeva