mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 03:57:19 +00:00
Filtering duplicate answers (#1021)
* Allow filtering of duplicate answers as implemented in FARM
* Changed default behavior to filtering exact duplicates
* Change expected test result due to filtering of duplicate answers by default
* Rounding expected test results for comparison with predictions
This commit is contained in:
parent
ca63f9fee2
commit
bf4563e5d2
@ -53,7 +53,8 @@ class FARMReader(BaseReader):
|
||||
num_processes: Optional[int] = None,
|
||||
max_seq_len: int = 256,
|
||||
doc_stride: int = 128,
|
||||
progress_bar: bool = True
|
||||
progress_bar: bool = True,
|
||||
duplicate_filtering: int = 0
|
||||
):
|
||||
|
||||
"""
|
||||
@ -91,6 +92,8 @@ class FARMReader(BaseReader):
|
||||
:param doc_stride: Length of striding window for splitting long texts (used if ``len(text) > max_seq_len``)
|
||||
:param progress_bar: Whether to show a tqdm progress bar or not.
|
||||
Can be helpful to disable in production deployments to keep the logs clean.
|
||||
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
|
||||
The higher the value, the further apart two answers' positions may be while still being treated as duplicates and filtered out. 0 filters only exact duplicates. -1 turns off duplicate removal.
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
@ -99,6 +102,7 @@ class FARMReader(BaseReader):
|
||||
batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
|
||||
top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
|
||||
num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
|
||||
duplicate_filtering=duplicate_filtering
|
||||
)
|
||||
|
||||
self.return_no_answers = return_no_answer
|
||||
@ -116,6 +120,10 @@ class FARMReader(BaseReader):
|
||||
self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample
|
||||
except:
|
||||
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
|
||||
try:
|
||||
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
|
||||
except:
|
||||
logger.warning("Could not set `duplicate_filtering` in FARM. Please update FARM version.")
|
||||
self.max_seq_len = max_seq_len
|
||||
self.use_gpu = use_gpu
|
||||
self.progress_bar = progress_bar
|
||||
|
||||
@ -132,7 +132,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
|
||||
index="haystack_test_eval_document",
|
||||
)
|
||||
assert eval_retriever.recall == 1.0
|
||||
assert eval_reader.top_k_f1 == 0.7
|
||||
assert round(eval_reader.top_k_f1, 4) == 0.8333
|
||||
assert eval_reader.top_k_em == 0.5
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user