diff --git a/haystack/reader/transformers.py b/haystack/reader/transformers.py
index 197c7a248..c43527323 100644
--- a/haystack/reader/transformers.py
+++ b/haystack/reader/transformers.py
@@ -23,6 +23,7 @@ class TransformersReader(BaseReader):
         context_window_size: int = 30,
         use_gpu: int = 0,
         n_best_per_passage: int = 2,
+        no_answer: bool = True
     ):
         """
         Load a QA model from Transformers.
@@ -39,11 +40,16 @@ class TransformersReader(BaseReader):
                                      The context usually helps users to understand if the answer really makes sense.
         :param use_gpu: < 0 -> use cpu
                         >= 0 -> ordinal of the gpu to use
+        :param n_best_per_passage: num of best answers to take into account for each passage
+        :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
+                          False -> otherwise
+
         """
         self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
         self.n_best_per_passage = n_best_per_passage
-        #TODO param to modify bias for no_answer
+        self.no_answer = no_answer
+
         # TODO context_window_size behaviour different from behavior in FARMReader
 
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@@ -76,25 +82,24 @@ class TransformersReader(BaseReader):
         answers = []
         for doc in documents:
             query = {"context": doc.text, "question": question}
-            predictions = self.model(query, topk=self.n_best_per_passage)
+            predictions = self.model(query, topk=self.n_best_per_passage,handle_impossible_answer=self.no_answer)
             # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
             if type(predictions) == dict:
                 predictions = [predictions]
             # assemble and format all answers
             for pred in predictions:
-                if pred["answer"]:
-                    context_start = max(0, pred["start"] - self.context_window_size)
-                    context_end = min(len(doc.text), pred["end"] + self.context_window_size)
-                    answers.append({
-                        "answer": pred["answer"],
-                        "context": doc.text[context_start:context_end],
-                        "offset_start": pred["start"],
-                        "offset_end": pred["end"],
-                        "probability": pred["score"],
-                        "score": None,
-                        "document_id": doc.id,
-                        "meta": doc.meta
-                    })
+                context_start = max(0, pred["start"] - self.context_window_size)
+                context_end = min(len(doc.text), pred["end"] + self.context_window_size)
+                answers.append({
+                    "answer": pred["answer"],
+                    "context": doc.text[context_start:context_end],
+                    "offset_start": pred["start"],
+                    "offset_end": pred["end"],
+                    "probability": pred["score"],
+                    "score": None,
+                    "document_id": doc.id,
+                    "meta": doc.meta
+                })
 
         # sort answers by their `probability` and select top-k
         answers = sorted(
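For context, a minimal usage sketch of the new flag (not part of the diff itself). The constructor parameters, the `predict()` signature, and the forwarding of `no_answer` to the pipeline's `handle_impossible_answer` come from the change above; the model name, the `Document` import path and constructor, and the exact shape of the returned dict are illustrative assumptions and may differ between haystack versions.

```python
from haystack.database.base import Document          # import path assumed; varies by haystack version
from haystack.reader.transformers import TransformersReader

# Assumed SQuAD2-style model, i.e. one trained to predict "no answer".
reader = TransformersReader(
    model="deepset/roberta-base-squad2",
    tokenizer="deepset/roberta-base-squad2",
    use_gpu=-1,                # < 0 -> CPU, per the docstring above
    n_best_per_passage=2,
    no_answer=True,            # forwarded to the HF pipeline as handle_impossible_answer
)

# Document construction shown schematically; required fields depend on the haystack version.
docs = [Document(id="1", text="Haystack is a framework for building search pipelines.")]

# With no_answer=True, an unanswerable question may yield an empty answer string,
# and such predictions are no longer filtered out by the reader.
result = reader.predict(question="Who wrote Faust?", documents=docs, top_k=3)
for answer in result.get("answers", []):
    print(repr(answer["answer"]), answer["probability"])
```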