mirror of https://github.com/deepset-ai/haystack.git
synced 2025-11-04 03:39:31 +00:00

satisfy mypy

This commit is contained in:
parent 537204e8c9
commit ba7178be7f
@@ -128,7 +128,7 @@ class PredictionHead(nn.Module):
         """
         raise NotImplementedError()
 
-    def logits_to_preds(self, logits):
+    def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs):
         """
         Implement this function in your special Prediction Head.
         Should combine turn logits into predictions.
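The hunk above widens the base-class signature of logits_to_preds so that subclass overrides which take extra, head-specific arguments stay compatible with the parent declaration. A minimal sketch of the pattern, with illustrative class and parameter names (Head, SpanHead, span_mask are not taken from the diff):

from typing import Any, Optional
import torch

class Head:
    def logits_to_preds(self, logits: torch.Tensor, **kwargs: Any):
        # The base method accepts **kwargs, so subclasses can add their own
        # keyword parameters without the override signature diverging from it.
        raise NotImplementedError()

class SpanHead(Head):
    def logits_to_preds(self, logits: torch.Tensor, span_mask: Optional[torch.Tensor] = None, **kwargs: Any):
        # Extra parameters are absorbed by the widened base signature.
        return logits.argmax(dim=-1)

print(SpanHead().logits_to_preds(torch.zeros(2, 3)))  # tensor([0, 0])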
@@ -350,7 +350,7 @@ class QuestionAnsweringHead(PredictionHead):
         return torch.div(logits, self.temperature_for_confidence)
 
 
-    def calibrate_conf(self, logits: List[torch.Tensor], label_all: List[torch.Tensor]):
+    def calibrate_conf(self, logits, label_all):
         """
         Learning a temperature parameter to apply temperature scaling to calibrate confidence scores
         """
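The docstring in this hunk refers to temperature scaling. As a rough, self-contained sketch of what learning a single temperature parameter for confidence calibration typically looks like (this is a generic illustration, not Haystack's actual implementation):

import torch
from torch import nn, optim

# A single learnable temperature; dividing logits by it before softmax
# rescales confidences without changing the argmax prediction.
temperature = nn.Parameter(torch.ones(1))
optimizer = optim.LBFGS([temperature], lr=0.01, max_iter=50)
loss_fn = nn.CrossEntropyLoss()

def calibrate(logits: torch.Tensor, labels: torch.Tensor) -> float:
    # logits: [n_samples, n_classes], labels: [n_samples]
    def closure():
        optimizer.zero_grad()
        loss = loss_fn(logits / temperature, labels)
        loss.backward()
        return loss
    optimizer.step(closure)
    return temperature.item()

print(calibrate(torch.randn(32, 2), torch.randint(0, 2, (32,))))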
@@ -462,13 +462,13 @@ class QuestionAnsweringHead(PredictionHead):
 
         return all_top_n
 
-    def get_top_candidates(self, sorted_candidates: torch.Tensor, start_end_matrix: torch.Tensor, sample_idx: int, start_matrix: torch.Tensor, end_matrix: torch.Tensor):
+    def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix):
         """ Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
         This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
         This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
 
         # Initialize some variables
-        top_candidates = []
+        top_candidates: List[QACandidate] = []
         n_candidates = sorted_candidates.shape[0]
         start_idx_candidates = set()
         end_idx_candidates = set()
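Besides dropping the tensor hints on get_top_candidates, the hunk annotates the empty top_candidates list, since a bare [] gives mypy no element type to work with. A small illustration of the pattern, using a hypothetical Candidate class in place of QACandidate:

from typing import List

class Candidate:
    # Hypothetical stand-in for QACandidate, only to show the annotation pattern.
    def __init__(self, score: float):
        self.score = score

def top_k(candidates: List[Candidate], k: int) -> List[Candidate]:
    top: List[Candidate] = []  # annotate the empty list so its element type is known
    for c in sorted(candidates, key=lambda c: c.score, reverse=True)[:k]:
        top.append(c)
    return top

print([c.score for c in top_k([Candidate(0.1), Candidate(0.9)], 1)])  # [0.9]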
@@ -496,7 +496,7 @@ class QuestionAnsweringHead(PredictionHead):
                                                   answer_type="span",
                                                   offset_unit="token",
                                                   aggregation_level="passage",
-                                                  passage_id=sample_idx,
+                                                  passage_id=str(sample_idx),
                                                   confidence=confidence))
                 if self.duplicate_filtering > -1:
                     for i in range(0, self.duplicate_filtering + 1):
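Wrapping sample_idx in str() here presumably matches an annotation such as passage_id: str on QACandidate, so that the raw int is no longer reported as a type mismatch. A tiny, hypothetical illustration of that kind of fix:

class Candidate:
    # Hypothetical stand-in class; only the parameter annotation matters here.
    def __init__(self, passage_id: str):
        self.passage_id = passage_id

sample_idx = 3
# Candidate(passage_id=sample_idx)             # would be flagged: expected "str", got "int"
candidate = Candidate(passage_id=str(sample_idx))  # explicit conversion keeps the annotation honest
print(candidate.passage_id)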
@@ -518,7 +518,7 @@ class QuestionAnsweringHead(PredictionHead):
 
         return top_candidates
 
-    def formatted_preds(self, logits: Optional[torch.Tensor] = None, preds: Optional[List[QACandidate]] = None, baskets: Optional[List[SampleBasket]] = None, **kwargs):
+    def formatted_preds(self,  preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs):
         """ Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
         predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
         ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
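Reordering formatted_preds so that preds and baskets are required, and only logits stays Optional, means the body can use them without None checks. A sketch of the general idea with made-up names (aggregate, and its string-typed arguments, are illustrative only):

from typing import List, Optional

def aggregate(preds: List[str], baskets: List[str], logits: Optional[object] = None) -> List[str]:
    # preds and baskets are required and never Optional, so they can be used
    # directly; logits is only accepted so that misuse can be reported.
    if logits is not None:
        raise ValueError("pass preds, not logits")
    return [f"{b}:{p}" for p, b in zip(preds, baskets)]

print(aggregate(["p1"], ["doc1"]))  # ['doc1:p1']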
@@ -533,10 +533,10 @@ class QuestionAnsweringHead(PredictionHead):
         if logits or preds is None:
             logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \
                             but was passed something different")
-        samples = [s for b in baskets for s in b.samples]
+        samples = [s for b in baskets for s in b.samples] #type: ignore
         ids = [s.id for s in samples]
-        passage_start_t = [s.features[0]["passage_start_t"] for s in samples]
-        seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples]
+        passage_start_t = [s.features[0]["passage_start_t"] for s in samples] #type: ignore
+        seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] #type: ignore
 
         # Aggregate passage level predictions to create document level predictions.
         # This method assumes that all passages of each document are contained in preds
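The #type: ignore comments silence the checker on these lines, presumably because the underlying attributes (such as the baskets' samples or the samples' features) are declared Optional somewhere upstream. An alternative pattern, shown only as a sketch with hypothetical Sample and passage_starts names, is to narrow the Optional before use:

from typing import List, Optional

class Sample:
    def __init__(self, features: Optional[List[dict]] = None):
        self.features = features

def passage_starts(samples: List[Sample]) -> List[int]:
    starts: List[int] = []
    for s in samples:
        assert s.features is not None  # narrows Optional[List[dict]] to List[dict]
        starts.append(s.features[0]["passage_start_t"])
    return starts

print(passage_starts([Sample([{"passage_start_t": 0}])]))  # [0]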
@@ -552,7 +552,7 @@ class QuestionAnsweringHead(PredictionHead):
 
         return doc_preds
 
-    def to_qa_preds(self, top_preds: Tuple[QACandidate], no_ans_gaps: Tuple[float], baskets: Tuple[SampleBasket]):
+    def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
         """ Groups Span objects together in a QAPred object  """
         ret = []
 
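Note that Tuple[QACandidate] in the removed line denotes a tuple of exactly one element; a variable-length tuple would be written Tuple[QACandidate, ...], which is likely part of why the hints were dropped here. For reference, a minimal illustration of the distinction:

from typing import Tuple

Pair = Tuple[int, str]      # exactly two elements: an int followed by a str
OneInt = Tuple[int]         # exactly one int, a common accidental annotation
ManyInts = Tuple[int, ...]  # any number of ints, the usual intent

def total(values: Tuple[int, ...]) -> int:
    return sum(values)

print(total((1, 2, 3)))  # 6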
@@ -998,7 +998,7 @@ class TextSimilarityHead(PredictionHead):
         loss = self.loss_fct(softmax_scores, targets)
         return loss
 
-    def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor:
+    def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor:  # type: ignore
         """
         Returns predicted ranks(similarity) of passages/context for each query
 
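The docstring here describes turning query/passage similarity scores into ranks. A rough sketch of that step, assuming dot-product similarity between query and passage embeddings (the variable names and shapes are assumptions, not the exact Haystack code):

import torch

query_vectors = torch.randn(2, 8)    # [n_queries, dim]
passage_vectors = torch.randn(5, 8)  # [n_passages, dim]

# Dot-product similarity for every query/passage pair: [n_queries, n_passages]
scores = query_vectors @ passage_vectors.T

# Sorting scores descending gives, per query, passage indices ordered best-first;
# these indices are the predicted ranking of passages for each query.
ranks = torch.argsort(scores, dim=1, descending=True)
print(ranks)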
@@ -29,7 +29,7 @@ class QACandidate:
 
     def __init__(self,
                  answer_type: str,
-                 score: str,
+                 score: float,
                  offset_answer_start: int,
                  offset_answer_end: int,
                  offset_unit: str,
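Correcting the score annotation to float matters beyond documentation, since anything inferred downstream (sorting, thresholding) depends on it. A small hypothetical example of the kind of code that benefits:

from typing import List

class Candidate:
    # Hypothetical stand-in for QACandidate.
    def __init__(self, answer: str, score: float):
        self.answer = answer
        self.score = score

def best(candidates: List[Candidate]) -> Candidate:
    # With score typed as float, numeric comparison is what the checker verifies;
    # with a str annotation, lexicographic comparison of scores would still
    # type-check even though it is numerically wrong.
    return max(candidates, key=lambda c: c.score)

print(best([Candidate("a", 0.2), Candidate("b", 0.9)]).answer)  # b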