This commit is contained in:
timoeller 2020-02-24 12:28:49 +01:00
parent ef9b99c3cc
commit 0f5b61d20a
2 changed files with 7 additions and 5 deletions

View File

@ -46,7 +46,7 @@ class Finder:
# 3) Apply reader to get granular answer(s)
logger.info(f"Applying the reader now to look for the answer in detail ...")
results = self.reader.predict(question=question,
paragrahps=paragraphs,
paragraphs=paragraphs,
meta_data_paragraphs=meta_data,
top_k=top_k_reader)

View File

@ -145,7 +145,7 @@ class FARMReader:
self.inferencer.model.save(directory)
self.inferencer.processor.save(directory)
def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1):
def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
"""
Use loaded QA model to find answers for a question in the supplied paragraphs.
@ -175,12 +175,12 @@ class FARMReader:
"""
if meta_data_paragraphs is None:
meta_data_paragraphs = len(paragrahps) * [None]
assert len(paragrahps) == len(meta_data_paragraphs)
meta_data_paragraphs = len(paragraphs) * [None]
assert len(paragraphs) == len(meta_data_paragraphs)
# convert input to FARM format
input_dicts = []
for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs):
for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
cur = {"text": paragraph,
"questions": [question],
"document_id": meta_data["document_id"]
@ -202,6 +202,8 @@ class FARMReader:
positive_found = False
for a in pred["predictions"][0]["answers"]:
# skip "no answers" here
# For now we only take one prediction from each passage
# TODO use more predictions per passage when setting n_candidates_per_passage + make FARM predictions more varied
if(not positive_found and a["answer"]):
cur = {"answer": a["answer"],
"score": a["score"],