mirror of https://github.com/deepset-ai/haystack.git
	Added support for unanswerable questions in TransformersReader (#258)
* Added support for unanswerable questions in TransformersReader

Co-authored-by: Antonio Lanza <anotniolanza1996@gmail.com>
This commit is contained in:
parent f0d901a374
commit b55de6f70a
@@ -23,6 +23,7 @@ class TransformersReader(BaseReader):
         context_window_size: int = 30,
         use_gpu: int = 0,
         n_best_per_passage: int = 2,
+        no_answer: bool = True
     ):
         """
         Load a QA model from Transformers.
@@ -39,11 +40,16 @@ class TransformersReader(BaseReader):
                             The context usually helps users to understand if the answer really makes sense.
         :param use_gpu: < 0  -> use cpu
                         >= 0 -> ordinal of the gpu to use
+        :param n_best_per_passage: num of best answers to take into account for each passage
+        :param no_answer: True -> Hugging Face model could return an "impossible"/"empty" answer (i.e. when there is an unanswerable question)
+                        False -> otherwise
+
         """
         self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
         self.n_best_per_passage = n_best_per_passage
-        #TODO param to modify bias for no_answer
+        self.no_answer = no_answer
+
         # TODO context_window_size behaviour different from behavior in FARMReader

     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@@ -76,13 +82,12 @@ class TransformersReader(BaseReader):
         answers = []
         for doc in documents:
             query = {"context": doc.text, "question": question}
-            predictions = self.model(query, topk=self.n_best_per_passage)
+            predictions = self.model(query, topk=self.n_best_per_passage, handle_impossible_answer=self.no_answer)
            # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
             if type(predictions) == dict:
                 predictions = [predictions]
             # assemble and format all answers
             for pred in predictions:
-                if pred["answer"]:
-                    context_start = max(0, pred["start"] - self.context_window_size)
-                    context_end = min(len(doc.text), pred["end"] + self.context_window_size)
-                    answers.append({
+                context_start = max(0, pred["start"] - self.context_window_size)
+                context_end = min(len(doc.text), pred["end"] + self.context_window_size)
+                answers.append({
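For context, a minimal sketch (not part of the commit) of what the new no_answer flag toggles under the hood: the handle_impossible_answer option of the Hugging Face question-answering pipeline. The model name and example texts below are illustrative assumptions, and the topk keyword matches the transformers API of this era (later renamed top_k).

# Sketch of the handle_impossible_answer behaviour exposed by `no_answer`.
# Model name and texts are assumptions, chosen for illustration only.
from transformers import pipeline

qa = pipeline("question-answering", model="deepset/roberta-base-squad2")

context = "Haystack is a framework for building search systems over large document collections."

# Answerable question: the pipeline extracts a span from the passage.
print(qa(question="What is Haystack?", context=context))

# Unanswerable question: with handle_impossible_answer=True the model is
# allowed to return an "impossible"/empty answer (start == end == 0)
# instead of being forced to pick some span from the passage.
print(qa(question="Who wrote Hamlet?", context=context,
         topk=2, handle_impossible_answer=True))

With the flag enabled, empty answers flow through predict() like any other prediction, which is why the commit also drops the old filter that skipped predictions with an empty answer string.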