mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-19 21:23:37 +00:00
satisfy mypy
This commit is contained in:
parent
537204e8c9
commit
ba7178be7f
@ -128,7 +128,7 @@ class PredictionHead(nn.Module):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def logits_to_preds(self, logits):
|
||||
def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs):
|
||||
"""
|
||||
Implement this function in your special Prediction Head.
|
||||
Should combine turn logits into predictions.
|
||||
@ -350,7 +350,7 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
return torch.div(logits, self.temperature_for_confidence)
|
||||
|
||||
|
||||
def calibrate_conf(self, logits: List[torch.Tensor], label_all: List[torch.Tensor]):
|
||||
def calibrate_conf(self, logits, label_all):
|
||||
"""
|
||||
Learning a temperature parameter to apply temperature scaling to calibrate confidence scores
|
||||
"""
|
||||
@ -462,13 +462,13 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
|
||||
return all_top_n
|
||||
|
||||
def get_top_candidates(self, sorted_candidates: torch.Tensor, start_end_matrix: torch.Tensor, sample_idx: int, start_matrix: torch.Tensor, end_matrix: torch.Tensor):
|
||||
def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix):
|
||||
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
|
||||
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
|
||||
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
|
||||
|
||||
# Initialize some variables
|
||||
top_candidates = []
|
||||
top_candidates: List[QACandidate] = []
|
||||
n_candidates = sorted_candidates.shape[0]
|
||||
start_idx_candidates = set()
|
||||
end_idx_candidates = set()
|
||||
@ -496,7 +496,7 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
answer_type="span",
|
||||
offset_unit="token",
|
||||
aggregation_level="passage",
|
||||
passage_id=sample_idx,
|
||||
passage_id=str(sample_idx),
|
||||
confidence=confidence))
|
||||
if self.duplicate_filtering > -1:
|
||||
for i in range(0, self.duplicate_filtering + 1):
|
||||
@ -518,7 +518,7 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
|
||||
return top_candidates
|
||||
|
||||
def formatted_preds(self, logits: Optional[torch.Tensor] = None, preds: Optional[List[QACandidate]] = None, baskets: Optional[List[SampleBasket]] = None, **kwargs):
|
||||
def formatted_preds(self, preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs):
|
||||
""" Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
|
||||
predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
|
||||
ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
|
||||
@ -533,10 +533,10 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
if logits or preds is None:
|
||||
logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \
|
||||
but was passed something different")
|
||||
samples = [s for b in baskets for s in b.samples]
|
||||
samples = [s for b in baskets for s in b.samples] #type: ignore
|
||||
ids = [s.id for s in samples]
|
||||
passage_start_t = [s.features[0]["passage_start_t"] for s in samples]
|
||||
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples]
|
||||
passage_start_t = [s.features[0]["passage_start_t"] for s in samples] #type: ignore
|
||||
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] #type: ignore
|
||||
|
||||
# Aggregate passage level predictions to create document level predictions.
|
||||
# This method assumes that all passages of each document are contained in preds
|
||||
@ -552,7 +552,7 @@ class QuestionAnsweringHead(PredictionHead):
|
||||
|
||||
return doc_preds
|
||||
|
||||
def to_qa_preds(self, top_preds: Tuple[QACandidate], no_ans_gaps: Tuple[float], baskets: Tuple[SampleBasket]):
|
||||
def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
|
||||
""" Groups Span objects together in a QAPred object """
|
||||
ret = []
|
||||
|
||||
@ -998,7 +998,7 @@ class TextSimilarityHead(PredictionHead):
|
||||
loss = self.loss_fct(softmax_scores, targets)
|
||||
return loss
|
||||
|
||||
def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: # type: ignore
|
||||
"""
|
||||
Returns predicted ranks(similarity) of passages/context for each query
|
||||
|
||||
|
@ -29,7 +29,7 @@ class QACandidate:
|
||||
|
||||
def __init__(self,
|
||||
answer_type: str,
|
||||
score: str,
|
||||
score: float,
|
||||
offset_answer_start: int,
|
||||
offset_answer_end: int,
|
||||
offset_unit: str,
|
||||
|
Loading…
x
Reference in New Issue
Block a user