Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-11-07 21:33:39 +00:00)
Align TransformersReader with FARMReader (#319)

* Align TransformersReader with FARMReader

parent 72b1013560
commit 3a95fe2006
@@ -19,11 +19,13 @@ class TransformersReader(BaseReader):
     def __init__(
         self,
         model: str = "distilbert-base-uncased-distilled-squad",
-        tokenizer: str = "distilbert-base-uncased",
-        context_window_size: int = 30,
+        tokenizer: Optional[str] = None,
+        context_window_size: int = 70,
         use_gpu: int = 0,
         top_k_per_candidate: int = 4,
-        no_answer: bool = True
+        return_no_answers: bool = True,
+        max_seq_len: int = 256,
+        doc_stride: int = 128
     ):
         """
         Load a QA model from Transformers.
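
For anyone skimming the signature change, here is a minimal construction sketch using the renamed and newly added arguments. The keyword names come straight from the new signature above; the import path and the concrete values are assumptions for a haystack checkout of this vintage, not part of the commit:

    from haystack.reader.transformers import TransformersReader  # import path assumed for this era of haystack

    reader = TransformersReader(
        model="distilbert-base-uncased-distilled-squad",
        tokenizer=None,             # now Optional[str]; None makes the HF pipeline load the model's own tokenizer
        context_window_size=70,     # new default, up from 30
        use_gpu=-1,                 # device index handed to the HF pipeline; -1 = CPU
        top_k_per_candidate=4,
        return_no_answers=True,     # renamed from no_answer to match FARMReader
        max_seq_len=256,            # new: inputs longer than this are split
        doc_stride=128,             # new: overlap between the resulting splits
    )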
@@ -44,14 +46,20 @@ class TransformersReader(BaseReader):
         Note: - This is not the number of "final answers" you will receive
         (see `top_k` in TransformersReader.predict() or Finder.get_answers() for that)
         - Can includes no_answer in the sorted list of predictions
-        :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
+        :param return_no_answers: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
                           False -> otherwise
+                          no_answer_boost is unfortunately not available with TransformersReader. If you would like to
+                          set no_answer_boost, use a FARMReader
+        :param max_seq_len: max sequence length of one input text for the model
+        :param doc_stride: length of striding window for splitting long texts (used if len(text) > max_seq_len)
+
         """
         self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
         self.top_k_per_candidate = top_k_per_candidate
-        self.return_no_answers = no_answer
+        self.return_no_answers = return_no_answers
+        self.max_seq_len = max_seq_len
+        self.doc_stride = doc_stride
+
         # TODO context_window_size behaviour different from behavior in FARMReader
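
The new max_seq_len/doc_stride pair is the standard sliding-window trick for long inputs: any text longer than max_seq_len tokens is cut into overlapping spans, each starting doc_stride tokens after the previous one. The actual splitting happens on tokenizer output inside the Hugging Face pipeline; the toy function below only illustrates the span arithmetic under that assumption:

    def sliding_windows(n_tokens: int, max_seq_len: int = 256, doc_stride: int = 128):
        """Yield overlapping (start, end) token spans covering a long text."""
        start = 0
        while True:
            end = min(start + max_seq_len, n_tokens)
            yield (start, end)
            if end == n_tokens:
                break
            start += doc_stride  # next window begins doc_stride tokens further in

    # A 500-token document under the new defaults:
    print(list(sliding_windows(500)))  # [(0, 256), (128, 384), (256, 500)]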
@@ -87,7 +95,11 @@ class TransformersReader(BaseReader):
         best_overall_score = 0
         for doc in documents:
             query = {"context": doc.text, "question": question}
-            predictions = self.model(query, topk=self.top_k_per_candidate, handle_impossible_answer=self.return_no_answers)
+            predictions = self.model(query,
+                                     topk=self.top_k_per_candidate,
+                                     handle_impossible_answer=self.return_no_answers,
+                                     max_seq_len=self.max_seq_len,
+                                     doc_stride=self.doc_stride)
             # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
             if type(predictions) == dict:
                 predictions = [predictions]
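
To see what the reader now delegates, the underlying Hugging Face call can be reproduced in isolation. This is a sketch against the transformers pipeline API of this period (where the kwarg was still topk; newer releases renamed it to top_k), with a made-up context string:

    from transformers import pipeline

    qa = pipeline("question-answering",
                  model="distilbert-base-uncased-distilled-squad",
                  device=-1)  # -1 = CPU

    predictions = qa({"context": "Haystack was created by deepset.", "question": "Who created Haystack?"},
                     topk=4,
                     handle_impossible_answer=True,  # what return_no_answers toggles
                     max_seq_len=256,
                     doc_stride=128)

    # As the comment in the diff notes: with a single prediction the pipeline
    # returns a bare dict rather than a list, so normalize before iterating.
    if isinstance(predictions, dict):
        predictions = [predictions]
    for p in predictions:
        print(p["answer"], p["score"], p["start"], p["end"])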