mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 02:09:39 +00:00
Align TransformersReader with FARMReader (#319)
* Align TransformersReader with FARMReader
This commit is contained in:
parent
72b1013560
commit
3a95fe2006
@ -19,11 +19,13 @@ class TransformersReader(BaseReader):
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "distilbert-base-uncased-distilled-squad",
|
||||
tokenizer: str = "distilbert-base-uncased",
|
||||
context_window_size: int = 30,
|
||||
tokenizer: Optional[str] = None,
|
||||
context_window_size: int = 70,
|
||||
use_gpu: int = 0,
|
||||
top_k_per_candidate: int = 4,
|
||||
no_answer: bool = True
|
||||
return_no_answers: bool = True,
|
||||
max_seq_len: int = 256,
|
||||
doc_stride: int = 128
|
||||
):
|
||||
"""
|
||||
Load a QA model from Transformers.
|
||||
@ -44,14 +46,20 @@ class TransformersReader(BaseReader):
|
||||
Note: - This is not the number of "final answers" you will receive
|
||||
(see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
|
||||
- Can include no_answer in the sorted list of predictions
|
||||
:param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
|
||||
False -> otherwise
|
||||
:param return_no_answers: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
|
||||
False -> otherwise
|
||||
no_answer_boost is unfortunately not available with TransformersReader. If you would like to
|
||||
set no_answer_boost, use a FARMReader
|
||||
:param max_seq_len: max sequence length of one input text for the model
|
||||
:param doc_stride: length of striding window for splitting long texts (used if len(text) > max_seq_len)
|
||||
|
||||
"""
|
||||
self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
|
||||
self.context_window_size = context_window_size
|
||||
self.top_k_per_candidate = top_k_per_candidate
|
||||
self.return_no_answers = no_answer
|
||||
self.return_no_answers = return_no_answers
|
||||
self.max_seq_len = max_seq_len
|
||||
self.doc_stride = doc_stride
|
||||
|
||||
# TODO: context_window_size behaves differently here than in FARMReader
|
||||
|
||||
@ -87,7 +95,11 @@ class TransformersReader(BaseReader):
|
||||
best_overall_score = 0
|
||||
for doc in documents:
|
||||
query = {"context": doc.text, "question": question}
|
||||
predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.return_no_answers)
|
||||
predictions = self.model(query,
|
||||
topk=self.top_k_per_candidate,
|
||||
handle_impossible_answer=self.return_no_answers,
|
||||
max_seq_len=self.max_seq_len,
|
||||
doc_stride=self.doc_stride)
|
||||
# for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
|
||||
if type(predictions) == dict:
|
||||
predictions = [predictions]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user