satisfy mypy

This commit is contained in:
Timo Moeller 2021-09-13 19:24:43 +02:00
parent 537204e8c9
commit ba7178be7f
2 changed files with 12 additions and 12 deletions

View File

@ -128,7 +128,7 @@ class PredictionHead(nn.Module):
""" """
raise NotImplementedError() raise NotImplementedError()
def logits_to_preds(self, logits): def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs):
""" """
Implement this function in your special Prediction Head. Implement this function in your special Prediction Head.
Should combine turn logits into predictions. Should combine turn logits into predictions.
@ -350,7 +350,7 @@ class QuestionAnsweringHead(PredictionHead):
return torch.div(logits, self.temperature_for_confidence) return torch.div(logits, self.temperature_for_confidence)
def calibrate_conf(self, logits: List[torch.Tensor], label_all: List[torch.Tensor]): def calibrate_conf(self, logits, label_all):
""" """
Learning a temperature parameter to apply temperature scaling to calibrate confidence scores Learning a temperature parameter to apply temperature scaling to calibrate confidence scores
""" """
@ -462,13 +462,13 @@ class QuestionAnsweringHead(PredictionHead):
return all_top_n return all_top_n
def get_top_candidates(self, sorted_candidates: torch.Tensor, start_end_matrix: torch.Tensor, sample_idx: int, start_matrix: torch.Tensor, end_matrix: torch.Tensor): def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix):
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits. """ Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens). This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)""" This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
# Initialize some variables # Initialize some variables
top_candidates = [] top_candidates: List[QACandidate] = []
n_candidates = sorted_candidates.shape[0] n_candidates = sorted_candidates.shape[0]
start_idx_candidates = set() start_idx_candidates = set()
end_idx_candidates = set() end_idx_candidates = set()
@ -496,7 +496,7 @@ class QuestionAnsweringHead(PredictionHead):
answer_type="span", answer_type="span",
offset_unit="token", offset_unit="token",
aggregation_level="passage", aggregation_level="passage",
passage_id=sample_idx, passage_id=str(sample_idx),
confidence=confidence)) confidence=confidence))
if self.duplicate_filtering > -1: if self.duplicate_filtering > -1:
for i in range(0, self.duplicate_filtering + 1): for i in range(0, self.duplicate_filtering + 1):
@ -518,7 +518,7 @@ class QuestionAnsweringHead(PredictionHead):
return top_candidates return top_candidates
def formatted_preds(self, logits: Optional[torch.Tensor] = None, preds: Optional[List[QACandidate]] = None, baskets: Optional[List[SampleBasket]] = None, **kwargs): def formatted_preds(self, preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs):
""" Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level """ Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
@ -533,10 +533,10 @@ class QuestionAnsweringHead(PredictionHead):
if logits or preds is None: if logits or preds is None:
logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \ logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \
but was passed something different") but was passed something different")
samples = [s for b in baskets for s in b.samples] samples = [s for b in baskets for s in b.samples] #type: ignore
ids = [s.id for s in samples] ids = [s.id for s in samples]
passage_start_t = [s.features[0]["passage_start_t"] for s in samples] passage_start_t = [s.features[0]["passage_start_t"] for s in samples] #type: ignore
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] #type: ignore
# Aggregate passage level predictions to create document level predictions. # Aggregate passage level predictions to create document level predictions.
# This method assumes that all passages of each document are contained in preds # This method assumes that all passages of each document are contained in preds
@ -552,7 +552,7 @@ class QuestionAnsweringHead(PredictionHead):
return doc_preds return doc_preds
def to_qa_preds(self, top_preds: Tuple[QACandidate], no_ans_gaps: Tuple[float], baskets: Tuple[SampleBasket]): def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
""" Groups Span objects together in a QAPred object """ """ Groups Span objects together in a QAPred object """
ret = [] ret = []
@ -998,7 +998,7 @@ class TextSimilarityHead(PredictionHead):
loss = self.loss_fct(softmax_scores, targets) loss = self.loss_fct(softmax_scores, targets)
return loss return loss
def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: # type: ignore
""" """
Returns predicted ranks(similarity) of passages/context for each query Returns predicted ranks(similarity) of passages/context for each query

View File

@ -29,7 +29,7 @@ class QACandidate:
def __init__(self, def __init__(self,
answer_type: str, answer_type: str,
score: str, score: float,
offset_answer_start: int, offset_answer_start: int,
offset_answer_end: int, offset_answer_end: int,
offset_unit: str, offset_unit: str,