This commit is contained in:
timoeller 2020-02-24 12:28:49 +01:00
parent ef9b99c3cc
commit 0f5b61d20a
2 changed files with 7 additions and 5 deletions

View File

@ -46,7 +46,7 @@ class Finder:
# 3) Apply reader to get granular answer(s) # 3) Apply reader to get granular answer(s)
logger.info(f"Applying the reader now to look for the answer in detail ...") logger.info(f"Applying the reader now to look for the answer in detail ...")
results = self.reader.predict(question=question, results = self.reader.predict(question=question,
paragrahps=paragraphs, paragraphs=paragraphs,
meta_data_paragraphs=meta_data, meta_data_paragraphs=meta_data,
top_k=top_k_reader) top_k=top_k_reader)

View File

@ -145,7 +145,7 @@ class FARMReader:
self.inferencer.model.save(directory) self.inferencer.model.save(directory)
self.inferencer.processor.save(directory) self.inferencer.processor.save(directory)
def predict(self, question, paragrahps, meta_data_paragraphs=None, top_k=None, max_processes=1): def predict(self, question, paragraphs, meta_data_paragraphs=None, top_k=None, max_processes=1):
""" """
Use loaded QA model to find answers for a question in the supplied paragraphs. Use loaded QA model to find answers for a question in the supplied paragraphs.
@ -175,12 +175,12 @@ class FARMReader:
""" """
if meta_data_paragraphs is None: if meta_data_paragraphs is None:
meta_data_paragraphs = len(paragrahps) * [None] meta_data_paragraphs = len(paragraphs) * [None]
assert len(paragrahps) == len(meta_data_paragraphs) assert len(paragraphs) == len(meta_data_paragraphs)
# convert input to FARM format # convert input to FARM format
input_dicts = [] input_dicts = []
for paragraph, meta_data in zip(paragrahps, meta_data_paragraphs): for paragraph, meta_data in zip(paragraphs, meta_data_paragraphs):
cur = {"text": paragraph, cur = {"text": paragraph,
"questions": [question], "questions": [question],
"document_id": meta_data["document_id"] "document_id": meta_data["document_id"]
@ -202,6 +202,8 @@ class FARMReader:
positive_found = False positive_found = False
for a in pred["predictions"][0]["answers"]: for a in pred["predictions"][0]["answers"]:
# skip "no answers" here # skip "no answers" here
# For now we only take one prediction from each passage
# TODO use more predictions per passage when setting n_candidates_per_passage + make FARM predictions more varied
if(not positive_found and a["answer"]): if(not positive_found and a["answer"]):
cur = {"answer": a["answer"], cur = {"answer": a["answer"],
"score": a["score"], "score": a["score"],