mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-27 09:04:11 +00:00
Merge pull request #331 from deepset-ai/robust_eval
More robust Reader eval by limiting max answers and creating no answer labels
This commit is contained in:
commit
0ad22d5038
@ -45,6 +45,7 @@ def eval_data_from_file(filename: str) -> Tuple[List[Document], List[Label]]:
|
|||||||
|
|
||||||
# Get Labels
|
# Get Labels
|
||||||
for qa in paragraph["qas"]:
|
for qa in paragraph["qas"]:
|
||||||
|
if len(qa["answers"]) > 0:
|
||||||
for answer in qa["answers"]:
|
for answer in qa["answers"]:
|
||||||
label = Label(
|
label = Label(
|
||||||
question=qa["question"],
|
question=qa["question"],
|
||||||
@ -57,7 +58,18 @@ def eval_data_from_file(filename: str) -> Tuple[List[Document], List[Label]]:
|
|||||||
origin="gold_label",
|
origin="gold_label",
|
||||||
)
|
)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
|
else:
|
||||||
|
label = Label(
|
||||||
|
question=qa["question"],
|
||||||
|
answer="",
|
||||||
|
is_correct_answer=True,
|
||||||
|
is_correct_document=True,
|
||||||
|
document_id=cur_doc.id,
|
||||||
|
offset_start_in_doc=0,
|
||||||
|
no_answer=qa["is_impossible"],
|
||||||
|
origin="gold_label",
|
||||||
|
)
|
||||||
|
labels.append(label)
|
||||||
return docs, labels
|
return docs, labels
|
||||||
|
|
||||||
|
|
||||||
|
@ -394,6 +394,11 @@ class FARMReader(BaseReader):
|
|||||||
:param doc_index: Index/Table name where documents that are used for evaluation are stored
|
:param doc_index: Index/Table name where documents that are used for evaluation are stored
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.top_k_per_candidate != 4:
|
||||||
|
logger.info(f"Performing Evaluation using top_k_per_candidate = {self.top_k_per_candidate} \n"
|
||||||
|
f"and consequently, QuestionAnsweringPredictionHead.n_best = {self.top_k_per_candidate + 1}. \n"
|
||||||
|
f"This deviates from FARM's default where QuestionAnsweringPredictionHead.n_best = 5")
|
||||||
|
|
||||||
# extract all questions for evaluation
|
# extract all questions for evaluation
|
||||||
filters = {"origin": [label_origin]}
|
filters = {"origin": [label_origin]}
|
||||||
|
|
||||||
@ -409,7 +414,8 @@ class FARMReader(BaseReader):
|
|||||||
|
|
||||||
# Create squad style dicts
|
# Create squad style dicts
|
||||||
d: Dict[str, Any] = {}
|
d: Dict[str, Any] = {}
|
||||||
for doc_id in aggregated_per_doc.keys():
|
all_doc_ids = [x.id for x in document_store.get_all_documents(doc_index)]
|
||||||
|
for doc_id in all_doc_ids:
|
||||||
doc = document_store.get_document_by_id(doc_id, index=doc_index)
|
doc = document_store.get_document_by_id(doc_id, index=doc_index)
|
||||||
if not doc:
|
if not doc:
|
||||||
logger.error(f"Document with the ID '{doc_id}' is not present in the document store.")
|
logger.error(f"Document with the ID '{doc_id}' is not present in the document store.")
|
||||||
@ -419,9 +425,13 @@ class FARMReader(BaseReader):
|
|||||||
}
|
}
|
||||||
# get all questions / answers
|
# get all questions / answers
|
||||||
aggregated_per_question: Dict[str, Any] = defaultdict(list)
|
aggregated_per_question: Dict[str, Any] = defaultdict(list)
|
||||||
|
if doc_id in aggregated_per_doc:
|
||||||
for label in aggregated_per_doc[doc_id]:
|
for label in aggregated_per_doc[doc_id]:
|
||||||
# add to existing answers
|
# add to existing answers
|
||||||
if label.question in aggregated_per_question.keys():
|
if label.question in aggregated_per_question.keys():
|
||||||
|
# Hack to fix problem where duplicate questions are merged by doc_store processing creating a QA example with 8 annotations > 6 annotation max
|
||||||
|
if len(aggregated_per_question[label.question]["answers"]) >= 6:
|
||||||
|
continue
|
||||||
aggregated_per_question[label.question]["answers"].append({
|
aggregated_per_question[label.question]["answers"].append({
|
||||||
"text": label.answer,
|
"text": label.answer,
|
||||||
"answer_start": label.offset_start_in_doc})
|
"answer_start": label.offset_start_in_doc})
|
||||||
|
@ -11,7 +11,7 @@ def test_add_eval_data(document_store):
|
|||||||
document_store.add_eval_data(filename="samples/squad/small.json", doc_index="test_eval_document", label_index="test_feedback")
|
document_store.add_eval_data(filename="samples/squad/small.json", doc_index="test_eval_document", label_index="test_feedback")
|
||||||
|
|
||||||
assert document_store.get_document_count(index="test_eval_document") == 87
|
assert document_store.get_document_count(index="test_eval_document") == 87
|
||||||
assert document_store.get_label_count(index="test_feedback") == 881
|
assert document_store.get_label_count(index="test_feedback") == 1214
|
||||||
|
|
||||||
# test documents
|
# test documents
|
||||||
docs = document_store.get_all_documents(index="test_eval_document")
|
docs = document_store.get_all_documents(index="test_eval_document")
|
||||||
|
@ -297,7 +297,7 @@
|
|||||||
"# Initialize Reader\n",
|
"# Initialize Reader\n",
|
||||||
"from haystack.reader.farm import FARMReader\n",
|
"from haystack.reader.farm import FARMReader\n",
|
||||||
"\n",
|
"\n",
|
||||||
"reader = FARMReader(\"deepset/roberta-base-squad2\")"
|
"reader = FARMReader(\"deepset/roberta-base-squad2\", top_k_per_candidate=4)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -75,7 +75,7 @@ retriever = ElasticsearchRetriever(document_store=document_store)
|
|||||||
|
|
||||||
|
|
||||||
# Initialize Reader
|
# Initialize Reader
|
||||||
reader = FARMReader("deepset/roberta-base-squad2")
|
reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4)
|
||||||
|
|
||||||
# Initialize Finder which sticks together Reader and Retriever
|
# Initialize Finder which sticks together Reader and Retriever
|
||||||
finder = Finder(reader, retriever)
|
finder = Finder(reader, retriever)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user