Filtering duplicate answers (#1021)

* Allow filtering of duplicate answers as implemented in FARM

* Changed default behavior to filtering exact duplicates

* Change expected test result due to filtering of duplicate answers by default

* Rounding expected test results for comparison with predictions
This commit is contained in:
Julian Risch 2021-05-03 17:18:10 +02:00 committed by GitHub
parent ca63f9fee2
commit bf4563e5d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -53,7 +53,8 @@ class FARMReader(BaseReader):
num_processes: Optional[int] = None,
max_seq_len: int = 256,
doc_stride: int = 128,
progress_bar: bool = True
progress_bar: bool = True,
duplicate_filtering: int = 0
):
"""
@ -91,6 +92,8 @@ class FARMReader(BaseReader):
:param doc_stride: Length of striding window for splitting long texts (used if ``len(text) > max_seq_len``)
:param progress_bar: Whether to show a tqdm progress bar or not.
Can be helpful to disable in production deployments to keep the logs clean.
:param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
The higher the value, answers that are more apart are filtered out. 0 corresponds to exact duplicates. -1 turns off duplicate removal.
"""
# save init parameters to enable export of component config as YAML
@ -99,6 +102,7 @@ class FARMReader(BaseReader):
batch_size=batch_size, use_gpu=use_gpu, no_ans_boost=no_ans_boost, return_no_answer=return_no_answer,
top_k=top_k, top_k_per_candidate=top_k_per_candidate, top_k_per_sample=top_k_per_sample,
num_processes=num_processes, max_seq_len=max_seq_len, doc_stride=doc_stride, progress_bar=progress_bar,
duplicate_filtering=duplicate_filtering
)
self.return_no_answers = return_no_answer
@ -116,6 +120,10 @@ class FARMReader(BaseReader):
self.inferencer.model.prediction_heads[0].n_best_per_sample = top_k_per_sample
except:
logger.warning("Could not set `top_k_per_sample` in FARM. Please update FARM version.")
try:
self.inferencer.model.prediction_heads[0].duplicate_filtering = duplicate_filtering
except:
logger.warning("Could not set `duplicate_filtering` in FARM. Please update FARM version.")
self.max_seq_len = max_seq_len
self.use_gpu = use_gpu
self.progress_bar = progress_bar

View File

@ -132,7 +132,7 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
index="haystack_test_eval_document",
)
assert eval_retriever.recall == 1.0
assert eval_reader.top_k_f1 == 0.7
assert round(eval_reader.top_k_f1, 4) == 0.8333
assert eval_reader.top_k_em == 0.5