mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-19 13:13:39 +00:00
satisfy mypy
This commit is contained in:
parent
537204e8c9
commit
ba7178be7f
@ -128,7 +128,7 @@ class PredictionHead(nn.Module):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def logits_to_preds(self, logits):
|
def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs):
|
||||||
"""
|
"""
|
||||||
Implement this function in your special Prediction Head.
|
Implement this function in your special Prediction Head.
|
||||||
Should combine turn logits into predictions.
|
Should combine turn logits into predictions.
|
||||||
@ -350,7 +350,7 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
return torch.div(logits, self.temperature_for_confidence)
|
return torch.div(logits, self.temperature_for_confidence)
|
||||||
|
|
||||||
|
|
||||||
def calibrate_conf(self, logits: List[torch.Tensor], label_all: List[torch.Tensor]):
|
def calibrate_conf(self, logits, label_all):
|
||||||
"""
|
"""
|
||||||
Learning a temperature parameter to apply temperature scaling to calibrate confidence scores
|
Learning a temperature parameter to apply temperature scaling to calibrate confidence scores
|
||||||
"""
|
"""
|
||||||
@ -462,13 +462,13 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
|
|
||||||
return all_top_n
|
return all_top_n
|
||||||
|
|
||||||
def get_top_candidates(self, sorted_candidates: torch.Tensor, start_end_matrix: torch.Tensor, sample_idx: int, start_matrix: torch.Tensor, end_matrix: torch.Tensor):
|
def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix):
|
||||||
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
|
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
|
||||||
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
|
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
|
||||||
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
|
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
|
||||||
|
|
||||||
# Initialize some variables
|
# Initialize some variables
|
||||||
top_candidates = []
|
top_candidates: List[QACandidate] = []
|
||||||
n_candidates = sorted_candidates.shape[0]
|
n_candidates = sorted_candidates.shape[0]
|
||||||
start_idx_candidates = set()
|
start_idx_candidates = set()
|
||||||
end_idx_candidates = set()
|
end_idx_candidates = set()
|
||||||
@ -496,7 +496,7 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
answer_type="span",
|
answer_type="span",
|
||||||
offset_unit="token",
|
offset_unit="token",
|
||||||
aggregation_level="passage",
|
aggregation_level="passage",
|
||||||
passage_id=sample_idx,
|
passage_id=str(sample_idx),
|
||||||
confidence=confidence))
|
confidence=confidence))
|
||||||
if self.duplicate_filtering > -1:
|
if self.duplicate_filtering > -1:
|
||||||
for i in range(0, self.duplicate_filtering + 1):
|
for i in range(0, self.duplicate_filtering + 1):
|
||||||
@ -518,7 +518,7 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
|
|
||||||
return top_candidates
|
return top_candidates
|
||||||
|
|
||||||
def formatted_preds(self, logits: Optional[torch.Tensor] = None, preds: Optional[List[QACandidate]] = None, baskets: Optional[List[SampleBasket]] = None, **kwargs):
|
def formatted_preds(self, preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs):
|
||||||
""" Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
|
""" Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
|
||||||
predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
|
predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
|
||||||
ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
|
ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
|
||||||
@ -533,10 +533,10 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
if logits or preds is None:
|
if logits or preds is None:
|
||||||
logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \
|
logger.error("QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None \
|
||||||
but was passed something different")
|
but was passed something different")
|
||||||
samples = [s for b in baskets for s in b.samples]
|
samples = [s for b in baskets for s in b.samples] #type: ignore
|
||||||
ids = [s.id for s in samples]
|
ids = [s.id for s in samples]
|
||||||
passage_start_t = [s.features[0]["passage_start_t"] for s in samples]
|
passage_start_t = [s.features[0]["passage_start_t"] for s in samples] #type: ignore
|
||||||
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples]
|
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] #type: ignore
|
||||||
|
|
||||||
# Aggregate passage level predictions to create document level predictions.
|
# Aggregate passage level predictions to create document level predictions.
|
||||||
# This method assumes that all passages of each document are contained in preds
|
# This method assumes that all passages of each document are contained in preds
|
||||||
@ -552,7 +552,7 @@ class QuestionAnsweringHead(PredictionHead):
|
|||||||
|
|
||||||
return doc_preds
|
return doc_preds
|
||||||
|
|
||||||
def to_qa_preds(self, top_preds: Tuple[QACandidate], no_ans_gaps: Tuple[float], baskets: Tuple[SampleBasket]):
|
def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
|
||||||
""" Groups Span objects together in a QAPred object """
|
""" Groups Span objects together in a QAPred object """
|
||||||
ret = []
|
ret = []
|
||||||
|
|
||||||
@ -998,7 +998,7 @@ class TextSimilarityHead(PredictionHead):
|
|||||||
loss = self.loss_fct(softmax_scores, targets)
|
loss = self.loss_fct(softmax_scores, targets)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor:
|
def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor: # type: ignore
|
||||||
"""
|
"""
|
||||||
Returns predicted ranks(similarity) of passages/context for each query
|
Returns predicted ranks(similarity) of passages/context for each query
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class QACandidate:
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
answer_type: str,
|
answer_type: str,
|
||||||
score: str,
|
score: float,
|
||||||
offset_answer_start: int,
|
offset_answer_start: int,
|
||||||
offset_answer_end: int,
|
offset_answer_end: int,
|
||||||
offset_unit: str,
|
offset_unit: str,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user