# Mirror of https://github.com/deepset-ai/haystack.git
# Synced 2025-07-16 05:20:51 +00:00
from farm.infer import Inferencer
|
|
import numpy as np
|
|
from scipy.special import expit
|
|
|
|
|
|
class FARMReader:
    """
    Implementation of FARM Inferencer for Question Answering.

    The class loads a saved FARM adaptive model from a given directory and runs
    inference using `inference_from_dicts()` method.
    """

    def __init__(
        self,
        model_dir,
        context_size=30,
        no_answer_shift=-100,
        batch_size=16,
        use_gpu=True,
    ):
        """
        Load a saved FARM model in Inference mode.

        :param model_dir: directory path of the saved model
        :param context_size: size of the surrounding context window returned
            with each answer (set on the QA prediction head)
        :param no_answer_shift: score shift applied to the "no answer" option
            (set on the QA prediction head)
        :param batch_size: number of samples per inference batch
        :param use_gpu: whether to run inference on a GPU
        """
        self.model = Inferencer.load(model_dir, batch_size=batch_size, gpu=use_gpu)
        # Configure the QA prediction head with the requested context window
        # and no-answer bias.
        self.model.model.prediction_heads[0].context_size = context_size
        self.model.model.prediction_heads[0].no_answer_shift = no_answer_shift

    def predict(self, input_dicts, top_k=None):
        """
        Run inference on the loaded model for the given input dicts.

        :param input_dicts: list of input dicts
        :param top_k: the maximum number of answers to return
        :return: dict of the form ``{"results": [prediction]}`` where
            ``prediction`` carries the metadata of the first FARM prediction
            plus at most ``top_k`` answers sorted by probability
        """
        results = self.model.inference_from_dicts(
            dicts=input_dicts, rest_api_schema=True, use_multiprocessing=False
        )

        # The FARM Inferencer as of now does not support multi document QA.
        # The QA inference is done for each text independently and the
        # results are sorted descending by their `score`.
        all_predictions = []
        for res in results:
            all_predictions.extend(res["predictions"])

        all_answers = []
        for pred in all_predictions:
            answers = pred["answers"]
            for a in answers:
                # Two sets of offset fields are returned by FARM -- context
                # level and document level. For the API, only context level
                # offsets are relevant, so translate the document-level answer
                # span into context-relative positions.
                a["offset_start"] = a["offset_answer_start"] - a["offset_context_start"]
                # BUGFIX: the context-relative end offset is the document-level
                # answer end minus the context start. The previous formula
                # (`offset_context_end - offset_answer_end`) yielded the
                # distance from the answer end to the END of the context.
                a["offset_end"] = a["offset_answer_end"] - a["offset_context_start"]
            all_answers.extend(answers)

        # Remove all null answers (where an answer is not found in the text).
        all_answers = [ans for ans in all_answers if ans["answer"]]

        # Squash raw model scores into (0, 1) pseudo-probabilities; the
        # divisor 8 is an empirical temperature for the logistic function.
        scores = np.asarray([ans["score"] for ans in all_answers])
        probabilities = expit(scores / 8)
        for ans, prob in zip(all_answers, probabilities):
            ans["probability"] = prob

        # Sort answers by their `probability`, best first.
        sorted_answers = sorted(
            all_answers, key=lambda k: k["probability"], reverse=True
        )

        # All predictions here are for the same question, so the metadata from
        # the first prediction in the list is taken.
        if all_predictions:
            resp = all_predictions[0]  # get the first prediction dict
            # slice [:None] returns all answers when top_k is not given
            resp["answers"] = sorted_answers[:top_k]
        else:
            # NOTE(review): this yields {"results": [[]]} — an empty list, not
            # an empty dict — kept as-is since callers may depend on it.
            resp = []

        return {"results": [resp]}