diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py index 2fd57274e..e04329688 100644 --- a/haystack/reader/farm.py +++ b/haystack/reader/farm.py @@ -53,7 +53,8 @@ class FARMReader(BaseReader): num_processes: Optional[int] = None, max_seq_len: int = 256, doc_stride: int = 128, - progress_bar: bool = True + progress_bar: bool = True, + duplicate_filtering: int = 0 ): """ @@ -91,6 +92,8 @@ class FARMReader(BaseReader): :param doc_stride: Length of striding window for splitting long texts (used if ``len(text) > max_seq_len``) :param progress_bar: Whether to show a tqdm progress bar or not. Can be helpful to disable in production deployments to keep the logs clean. + :param duplicate_filtering: Answers are filtered based on their position. Both the start and end positions of the answers are considered. + The higher the value, the further apart two answers can be and still be filtered out as duplicates. 0 corresponds to exact duplicates. -1 turns off duplicate removal. """ # save init parameters to enable export of component config as YAML @@ -99,6 +102,7 @@ class FARMReader(BaseReader): batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer, top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample, num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar, + duplicate_filtering=duplicate_filtering ) self.return_no_answers = return_no_answer @@ -116,6 +120,10 @@ class FARMReader(BaseReader): self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample except: logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.") + try: + self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering + except: + logger.warning("Could not set `duplicate_filtering` in FARM. 
Please update FARM version.") self.max_seq_len = max_seq_len self.use_gpu = use_gpu self.progress_bar = progress_bar diff --git a/test/test_eval.py b/test/test_eval.py index e8d3091e5..211215f10 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -132,7 +132,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): index="haystack_test_eval_document", ) assert eval_retriever.recall == 1.0 - assert eval_reader.top_k_f1 == 0.7 + assert round(eval_reader.top_k_f1, 4) == 0.8333 assert eval_reader.top_k_em == 0.5