diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py index 32f0b5473..e8e7ad4af 100644 --- a/haystack/modeling/model/prediction_head.py +++ b/haystack/modeling/model/prediction_head.py @@ -128,7 +128,7 @@ class PredictionHead(nn.Module): """ raise NotImplementedError() - def logits_to_preds(self, logits): + def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs): """ Implement this function in your special Prediction Head. Should combine turn logits into predictions. @@ -350,7 +350,7 @@ class QuestionAnsweringHead(PredictionHead): return torch.div(logits, self.temperature_for_confidence) - def calibrate_conf(self, logits: List[torch.Tensor], label_all: List[torch.Tensor]): + def calibrate_conf(self, logits, label_all): """ Learning a temperature parameter to apply temperature scaling to calibrate confidence scores """ @@ -462,13 +462,13 @@ class QuestionAnsweringHead(PredictionHead): return all_top_n - def get_top_candidates(self, sorted_candidates: torch.Tensor, start_end_matrix: torch.Tensor, sample_idx: int, start_matrix: torch.Tensor, end_matrix: torch.Tensor): + def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix): """ Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits. This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens). This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)""" # Initialize some variables - top_candidates = [] + top_candidates: List[QACandidate] = [] n_candidates = sorted_candidates.shape[0] start_idx_candidates = set() end_idx_candidates = set() @@ -496,7 +496,7 @@ class QuestionAnsweringHead(PredictionHead): answer_type="span", offset_unit="token", aggregation_level="passage", - passage_id=sample_idx, + passage_id=str(sample_idx), confidence=confidence)) if self.duplicate_filtering > -1: for i in range(0, self.duplicate_filtering + 1): @@ -518,7 +518,7 @@ class QuestionAnsweringHead(PredictionHead): return top_candidates - def formatted_preds(self, logits: Optional[torch.Tensor] = None, preds: Optional[List[QACandidate]] = None, baskets: Optional[List[SampleBasket]] = None, **kwargs): + def formatted_preds(self, preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs): """ Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have @@ -533,10 +533,10 @@ class QuestionAnsweringHead(PredictionHead): if logits or preds is None: logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \ but was passed something different") - samples = [s for b in baskets for s in b.samples] + samples = [s for b in baskets for s in b.samples] #type: ignore ids = [s.id for s in samples] - passage_start_t = [s.features[0]["passage_start_t"] for s in samples] - seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] + passage_start_t = [s.features[0]["passage_start_t"] for s in samples] #type: ignore + seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] #type: ignore # Aggregate passage level predictions to create document level predictions. # This method assumes that all passages of each document are contained in preds @@ -552,7 +552,7 @@ class QuestionAnsweringHead(PredictionHead): return doc_preds - def to_qa_preds(self, top_preds: Tuple[QACandidate], no_ans_gaps: Tuple[float], baskets: Tuple[SampleBasket]): + def to_qa_preds(self, top_preds, no_ans_gaps, baskets): """ Groups Span objects together in a QAPred object """ ret = [] @@ -998,7 +998,7 @@ class TextSimilarityHead(PredictionHead): loss = self.loss_fct(softmax_scores, targets) return loss - def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: + def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: # type: ignore """ Returns predicted ranks(similarity) of passages/context for each query diff --git a/haystack/modeling/model/predictions.py b/haystack/modeling/model/predictions.py index 208e9a9ba..5914f232a 100644 --- a/haystack/modeling/model/predictions.py +++ b/haystack/modeling/model/predictions.py @@ -29,7 +29,7 @@ class QACandidate: def __init__(self, answer_type: str, - score: str, + score: float, offset_answer_start: int, offset_answer_end: int, offset_unit: str,