Align TransformersReader with FARMReader (#319)

* Align TransformersReader with FARMReader
This commit is contained in:
bogdankostic 2020-08-18 14:26:33 +02:00 committed by GitHub
parent 72b1013560
commit 3a95fe2006
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -19,11 +19,13 @@ class TransformersReader(BaseReader):
def __init__(
self,
model: str = "distilbert-base-uncased-distilled-squad",
tokenizer: str = "distilbert-base-uncased",
context_window_size: int = 30,
tokenizer: Optional[str] = None,
context_window_size: int = 70,
use_gpu: int = 0,
top_k_per_candidate: int = 4,
no_answer: bool = True
return_no_answers: bool = True,
max_seq_len: int = 256,
doc_stride: int = 128
):
"""
Load a QA model from Transformers.
@@ -44,14 +46,20 @@ class TransformersReader(BaseReader):
Note: - This is not the number of "final answers" you will receive
(see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
- Can include no_answer in the sorted list of predictions
:param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
False -> otherwise
:param return_no_answers: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
False -> otherwise
no_answer_boost is unfortunately not available with TransformersReader. If you would like to
set no_answer_boost, use a FARMReader
:param max_seq_len: max sequence length of one input text for the model
:param doc_stride: length of striding window for splitting long texts (used if len(text) > max_seq_len)
"""
self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
self.context_window_size = context_window_size
self.top_k_per_candidate = top_k_per_candidate
self.return_no_answers = no_answer
self.return_no_answers = return_no_answers
self.max_seq_len = max_seq_len
self.doc_stride = doc_stride
# TODO context_window_size behaviour different from behavior in FARMReader
@@ -87,7 +95,11 @@ class TransformersReader(BaseReader):
best_overall_score = 0
for doc in documents:
query = {"context": doc.text, "question": question}
predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.return_no_answers)
predictions = self.model(query,
topk=self.top_k_per_candidate,
handle_impossible_answer=self.return_no_answers,
max_seq_len=self.max_seq_len,
doc_stride=self.doc_stride)
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
if type(predictions) == dict:
predictions = [predictions]