import json
import logging
import os
from pathlib import Path
from typing import List, Tuple, Optional, Union, Dict

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.nn import CrossEntropyLoss, NLLLoss
from transformers import AutoModelForQuestionAnswering

from haystack.modeling.data_handler.samples import SampleBasket
from haystack.modeling.model.predictions import QACandidate, QAPred
from haystack.modeling.utils import try_get, all_gather_list


logger = logging.getLogger(__name__)


try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError):
    logger.debug("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    BertLayerNorm = torch.nn.LayerNorm


class PredictionHead(nn.Module):
    """
    Takes word embeddings from a language model and generates logits for a given task. Can also convert logits
    to loss and logits to predictions.
    """

    subclasses = {}  # type: Dict

    def __init_subclass__(cls, **kwargs):
        """This automatically keeps track of all available subclasses.
        Enables generic load() for all specific PredictionHead implementations.
        """
        super().__init_subclass__(**kwargs)
        cls.subclasses[cls.__name__] = cls

    @classmethod
    def create(cls, prediction_head_name: str, layer_dims: List[int], class_weights: Optional[List[float]] = None):
        """
        Create subclass of Prediction Head.

        :param prediction_head_name: Classname (exact string!) of prediction head we want to create
        :param layer_dims: describing the feed forward block structure, e.g. [768, 2]
        :param class_weights: The loss weighting to be assigned to certain label classes during training.
                              Used to correct cases where there is a strong class imbalance.
        :return: Prediction Head of class prediction_head_name
        """
        # TODO: maybe we want to make this more generic.
        # 1. Class weights is not relevant for all heads.
        # 2. Layer dims impose a FF structure, maybe we want sth else later
        # Solution: We could again use **kwargs
        return cls.subclasses[prediction_head_name](layer_dims=layer_dims, class_weights=class_weights)
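
    # Registry sketch (comment only, so importing this module stays side-effect
    # free). __init_subclass__ runs at class-definition time, so by the end of
    # this module PredictionHead.subclasses maps e.g. "QuestionAnsweringHead"
    # and "TextSimilarityHead" to their classes; create() and load() dispatch
    # on that string:
    #
    #     head = PredictionHead.create("QuestionAnsweringHead", layer_dims=[768, 2])
    #     assert isinstance(head, PredictionHead.subclasses["QuestionAnsweringHead"])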

    def save_config(self, save_dir: Union[str, Path], head_num: int = 0):
        """
        Saves the config as a json file.

        :param save_dir: Path to save config to
        :param head_num: Which head to save
        """
        # updating config in case the parameters have been changed
        self.generate_config()
        output_config_file = Path(save_dir) / f"prediction_head_{head_num}_config.json"
        with open(output_config_file, "w") as file:
            json.dump(self.config, file)

    def save(self, save_dir: Union[str, Path], head_num: int = 0):
        """
        Saves the prediction head state dict.

        :param save_dir: path to save prediction head to
        :param head_num: which head to save
        """
        output_model_file = Path(save_dir) / f"prediction_head_{head_num}.bin"
        torch.save(self.state_dict(), output_model_file)
        self.save_config(save_dir, head_num)

    def generate_config(self):
        """
        Generates config file from Class parameters (only for sensible config parameters).
        """
        config = {}
        for key, value in self.__dict__.items():
            if type(value) is np.ndarray:
                value = value.tolist()
            if _is_json(value) and key[0] != "_":
                config[key] = value
            if self.task_name == "text_similarity" and key == "similarity_function":
                config["similarity_function"] = value
        config["name"] = self.__class__.__name__
        config.pop("config", None)
        self.config = config
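
    # Example of what generate_config() might produce for a QuestionAnsweringHead
    # with default arguments (illustrative; the exact keys depend on which
    # attributes are JSON-serializable):
    #
    #     {"layer_dims": [768, 2], "task_name": "question_answering",
    #      "no_ans_boost": 0.0, "context_window_size": 100, "n_best": 5,
    #      "n_best_per_sample": 5, "duplicate_filtering": -1,
    #      "name": "QuestionAnsweringHead", ...}
    #
    # The "name" entry is what load() uses to look up the right subclass.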

    @classmethod
    def load(cls, config_file: str, strict: bool = True, load_weights: bool = True):
        """
        Loads a Prediction Head. Infers the class of prediction head from config_file.

        :param config_file: location where corresponding config is stored
        :param strict: whether to strictly enforce that the keys loaded from saved model match the ones in
                       the PredictionHead (see torch.nn.module.load_state_dict()).
                       Set to `False` for backwards compatibility with PHs saved with an older version of Haystack.
        :param load_weights: whether to load weights of the prediction head
        :return: PredictionHead
        :rtype: PredictionHead[T]
        """
        with open(config_file) as file:
            config = json.load(file)
        prediction_head = cls.subclasses[config["name"]](**config)
        if load_weights:
            model_file = cls._get_model_file(config_file=config_file)
            logger.info("Loading prediction head from {}".format(model_file))
            prediction_head.load_state_dict(torch.load(model_file, map_location=torch.device("cpu")), strict=strict)
        return prediction_head
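
    # Round-trip sketch (comment only; `save_dir` is a hypothetical, existing
    # directory):
    #
    #     head = QuestionAnsweringHead()
    #     head.save(save_dir)   # writes prediction_head_0.bin + prediction_head_0_config.json
    #     head2 = PredictionHead.load(f"{save_dir}/prediction_head_0_config.json")
    #
    # load() reads the config, looks up the subclass via config["name"],
    # instantiates it with the stored kwargs, then loads the matching .bin state dict.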

    def logits_to_loss(self, logits, labels):
        """
        Implement this function in your special Prediction Head.
        Should combine logits and labels with a loss fct to a per sample loss.

        :param logits: logits, can vary in shape and type, depending on task
        :param labels: labels, can vary in shape and type, depending on task
        :return: per sample loss as a torch.tensor of shape [batch_size]
        """
        raise NotImplementedError()

    def logits_to_preds(self, logits, span_mask, start_of_word, seq_2_start_t, max_answer_length, **kwargs):
        """
        Implement this function in your special Prediction Head.
        Should turn logits into predictions.

        :param logits: logits, can vary in shape and type, depending on task
        :return: predictions as a torch.tensor of shape [batch_size]
        """
        raise NotImplementedError()

    def prepare_labels(self, **kwargs):
        """
        Some prediction heads need additional label conversion.

        :param kwargs: placeholder for passing generic parameters
        :return: labels in the right format
        :rtype: object
        """
        # TODO maybe just return **kwargs to not force people to implement this
        raise NotImplementedError()

    def resize_input(self, input_dim):
        """
        This function compares the output dimensionality of the language model against the input dimensionality
        of the prediction head. If there is a mismatch, the prediction head will be resized to fit.
        """
        # Note on pylint disable
        # self.feed_forward's existence seems to be a condition for its own initialization
        # within this class, which is clearly wrong. The only way this code could ever be called is
        # thanks to subclasses initializing self.feed_forward somewhere else; however, this is a
        # very implicit requirement for subclasses, and in general bad design. FIXME when possible.
        if "feed_forward" not in dir(self):
            return
        old_dims = self.feed_forward.layer_dims  # pylint: disable=access-member-before-definition
        if input_dim == old_dims[0]:
            return
        new_dims = [input_dim] + old_dims[1:]
        logger.info(
            f"Resizing input dimensions of {type(self).__name__} ({self.task_name}) "
            f"from {old_dims} to {new_dims} to match language model"
        )
        self.feed_forward = FeedForwardBlock(new_dims)
        self.layer_dims[0] = input_dim
        self.feed_forward.layer_dims[0] = input_dim

    @classmethod
    def _get_model_file(cls, config_file: Union[str, Path]):
        if "config.json" in str(config_file) and "prediction_head" in str(config_file):
            head_num = int("".join([char for char in os.path.basename(config_file) if char.isdigit()]))
            model_file = Path(os.path.dirname(config_file)) / f"prediction_head_{head_num}.bin"
        else:
            raise ValueError(f"This doesn't seem to be a proper prediction_head config file: '{config_file}'")
        return model_file

    def _set_name(self, name):
        self.task_name = name


class FeedForwardBlock(nn.Module):
    """
    A feed forward neural network of variable depth and width.
    """

    def __init__(self, layer_dims: List[int], **kwargs):
        # TODO: Consider having just one input argument
        super(FeedForwardBlock, self).__init__()
        self.layer_dims = layer_dims
        # If read from a config file, the input will be a string
        n_layers = len(layer_dims) - 1
        layers_all = []
        # TODO: Is this needed?
        self.output_size = layer_dims[-1]

        for i in range(n_layers):
            size_in = layer_dims[i]
            size_out = layer_dims[i + 1]
            layer = nn.Linear(size_in, size_out)
            layers_all.append(layer)
        self.feed_forward = nn.Sequential(*layers_all)

    def forward(self, X: torch.Tensor):
        logits = self.feed_forward(X)
        return logits
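
# Shape sketch for FeedForwardBlock (illustrative comment, not executed on import):
#
#     block = FeedForwardBlock([768, 256, 2])   # -> Linear(768, 256), Linear(256, 2)
#     x = torch.rand(8, 384, 768)               # [batch, seq_len, hidden]
#     block(x).shape                            # torch.Size([8, 384, 2])
#
# Note there is no non-linearity between the Linear layers; the block is a plain
# stack of affine maps, which for the usual [hidden, 2] case is a single projection.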


class QuestionAnsweringHead(PredictionHead):
    """
    A question answering head predicts the start and end of the answer on token level.

    In addition, it gives a score for the prediction so that multiple answers can be ranked.
    There are three different kinds of scores available:
    1) (standard) score: the sum of the logits of the start and end index. This score is unbounded because the logits are unbounded.
       It is the default for ranking answers.
    2) confidence score: also based on the logits of the start and end index but scales them to the interval 0 to 1 and incorporates no_answer.
       It can be used for ranking by setting use_confidence_scores_for_ranking to True
    3) calibrated confidence score: same as 2) but divides the logits by a learned temperature_for_confidence parameter
       so that the confidence scores are closer to the model's achieved accuracy. It can be used for ranking by setting
       use_confidence_scores_for_ranking to True and temperature_for_confidence != 1.0. See examples/question_answering_confidence.py for more details.
    """

    def __init__(
        self,
        layer_dims: List[int] = [768, 2],
        task_name: str = "question_answering",
        no_ans_boost: float = 0.0,
        context_window_size: int = 100,
        n_best: int = 5,
        n_best_per_sample: Optional[int] = None,
        duplicate_filtering: int = -1,
        temperature_for_confidence: float = 1.0,
        use_confidence_scores_for_ranking: bool = False,
        **kwargs,
    ):
        """
        :param layer_dims: dimensions of Feed Forward block, e.g. [768, 2], for adjusting to BERT embedding. Output should always be 2
        :param kwargs: placeholder for passing generic parameters
        :param no_ans_boost: How much the no_answer logit is boosted/increased.
                             The higher the value, the more likely a "no answer possible given the input text" is returned by the model
        :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
        :param n_best: The number of positive answer spans for each document.
        :param n_best_per_sample: num candidate answer spans to consider from each passage. Each passage also returns "no answer" info.
                                  This is decoupled from n_best on document level, since predictions on passage level are very similar.
                                  It should have a low value
        :param duplicate_filtering: Answers are filtered based on their position. Both start and end position of the answers are considered.
                                    Answers whose start or end lies within this many tokens of an already selected answer are filtered out.
                                    0 corresponds to exact duplicates. -1 turns off duplicate removal.
        :param temperature_for_confidence: The divisor that is used to scale logits to calibrate confidence scores
        :param use_confidence_scores_for_ranking: Whether to sort answers by confidence score (normalized between 0 and 1) or by standard score (unbounded) (default).
        """
        super(QuestionAnsweringHead, self).__init__()
        if len(kwargs) > 0:
            logger.warning(
                f"Some unused parameters are passed to the QuestionAnsweringHead. "
                f"Might not be a problem. Params: {json.dumps(kwargs)}"
            )
        self.layer_dims = layer_dims
        assert self.layer_dims[-1] == 2
        self.feed_forward = FeedForwardBlock(self.layer_dims)
        logger.debug(f"Prediction head initialized with size {self.layer_dims}")
        self.num_labels = self.layer_dims[-1]
        self.ph_output_type = "per_token_squad"
        self.model_type = "span_classification"  # predicts start and end token of answer
        self.task_name = task_name
        self.no_ans_boost = no_ans_boost
        self.context_window_size = context_window_size
        self.n_best = n_best
        if n_best_per_sample:
            self.n_best_per_sample = n_best_per_sample
        else:
            # increasing n_best_per_sample to n_best ensures that there are n_best predictions in total
            # otherwise this might not be the case for very short documents with only one "sample"
            self.n_best_per_sample = n_best
        self.duplicate_filtering = duplicate_filtering
        self.generate_config()
        self.temperature_for_confidence = nn.Parameter(torch.ones(1) * temperature_for_confidence)
        self.use_confidence_scores_for_ranking = use_confidence_scores_for_ranking

    @classmethod
    def load(cls, pretrained_model_name_or_path: Union[str, Path], revision: Optional[str] = None, **kwargs):  # type: ignore
        """
        Load a prediction head from a saved Haystack or transformers model. `pretrained_model_name_or_path`
        can be one of the following:
        a) Local path to a Haystack prediction head config (e.g. my-bert/prediction_head_0_config.json)
        b) Local path to a Transformers model (e.g. my-bert)
        c) Name of a public model from https://huggingface.co/models (e.g. distilbert-base-uncased-distilled-squad)

        :param pretrained_model_name_or_path: local path of a saved model or name of a publicly available model.
                                              Exemplary public names:
                                              - distilbert-base-uncased-distilled-squad
                                              - bert-large-uncased-whole-word-masking-finetuned-squad

                                              See https://huggingface.co/models for full list
        :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        """
        if (
            os.path.exists(pretrained_model_name_or_path)
            and "config.json" in str(pretrained_model_name_or_path)
            and "prediction_head" in str(pretrained_model_name_or_path)
        ):
            # a) Haystack style
            head = super(QuestionAnsweringHead, cls).load(str(pretrained_model_name_or_path))
        else:
            # b) transformers style
            # load all weights from model
            full_qa_model = AutoModelForQuestionAnswering.from_pretrained(
                pretrained_model_name_or_path, revision=revision, **kwargs
            )
            # init empty head
            head = cls(layer_dims=[full_qa_model.config.hidden_size, 2], task_name="question_answering")
            # transfer weights for head from full model
            head.feed_forward.feed_forward[0].load_state_dict(full_qa_model.qa_outputs.state_dict())
            del full_qa_model

        return head
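
    # Loading sketch (comment only; downloads a public model when uncommented):
    #
    #     head = QuestionAnsweringHead.load("distilbert-base-uncased-distilled-squad")
    #
    # This takes branch b): the full transformers QA model is fetched, its final
    # qa_outputs projection is copied into this head's only Linear layer, and the
    # rest of the full model is discarded.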

    def forward(self, X: torch.Tensor):
        """
        One forward pass through the prediction head model, starting with language model output on token level.
        """
        logits = self.feed_forward(X)
        return self.temperature_scale(logits)

    def logits_to_loss(self, logits: torch.Tensor, labels: torch.Tensor, **kwargs):
        """
        Combine predictions and labels to a per sample loss.
        """
        # TODO: explain how we only use the first answer for training
        # labels.shape = [batch_size, n_max_answers, 2]. n_max_answers is by default 6 since this is the
        # most that occurs in the SQuAD dev set. The 2 in the final dimension corresponds to [start, end]
        start_position = labels[:, 0, 0]
        end_position = labels[:, 0, 1]

        # logits is of shape [batch_size, max_seq_len, 2]. Like above, the final dimension corresponds to [start, end]
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Squeeze final singleton dimensions
        if len(start_position.size()) > 1:
            start_position = start_position.squeeze(-1)
        if len(end_position.size()) > 1:
            end_position = end_position.squeeze(-1)

        ignored_index = start_logits.size(1)
        start_position.clamp_(0, ignored_index)
        end_position.clamp_(0, ignored_index)

        # Workaround for pytorch bug in version 1.10.0 with non-contiguous tensors
        # Fix expected in 1.10.1 based on https://github.com/pytorch/pytorch/pull/64954
        start_logits = start_logits.contiguous()
        start_position = start_position.contiguous()
        end_logits = end_logits.contiguous()
        end_position = end_position.contiguous()

        loss_fct = CrossEntropyLoss(reduction="none")
        start_loss = loss_fct(start_logits, start_position)
        end_loss = loss_fct(end_logits, end_position)
        per_sample_loss = (start_loss + end_loss) / 2
        return per_sample_loss

    def temperature_scale(self, logits: torch.Tensor):
        return torch.div(logits, self.temperature_for_confidence)

    def calibrate_conf(self, logits, label_all):
        """
        Learns a temperature parameter to apply temperature scaling to calibrate confidence scores.
        """
        logits = torch.cat(logits, dim=0)

        # To handle no_answer labels correctly (-1,-1), we set their start_position to 0. The logit at index 0 also refers to no_answer
        # TODO some language models do not have the CLS token at position 0. For these models, we need to map start_position==-1 to the index of CLS token
        start_position = [label[0][0] if label[0][0] >= 0 else 0 for label in label_all]
        end_position = [label[0][1] if label[0][1] >= 0 else 0 for label in label_all]

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        start_position = torch.tensor(start_position)
        if len(start_position.size()) > 1:
            start_position = start_position.squeeze(-1)
        end_position = torch.tensor(end_position)
        if len(end_position.size()) > 1:
            end_position = end_position.squeeze(-1)

        ignored_index = start_logits.size(1) - 1
        start_position.clamp_(0, ignored_index)
        end_position.clamp_(0, ignored_index)

        nll_criterion = CrossEntropyLoss()

        optimizer = optim.LBFGS([self.temperature_for_confidence], lr=0.01, max_iter=50)

        def eval_start_end_logits():
            loss = nll_criterion(
                self.temperature_scale(start_logits), start_position.to(device=start_logits.device)
            ) + nll_criterion(self.temperature_scale(end_logits), end_position.to(device=end_logits.device))
            loss.backward()
            return loss

        optimizer.step(eval_start_end_logits)
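
    # Temperature-scaling sketch (comment only). Dividing logits by a temperature
    # T > 1 flattens the softmax, lowering over-confident probabilities; T < 1
    # sharpens it. With illustrative numbers:
    #
    #     logits = torch.tensor([4.0, 1.0, 0.0])
    #     torch.softmax(logits, -1)        # ~[0.94, 0.05, 0.02]
    #     torch.softmax(logits / 2.0, -1)  # ~[0.74, 0.16, 0.10]
    #
    # calibrate_conf() picks T by minimizing the NLL of the held-out start/end
    # labels with LBFGS, leaving the argmax (and thus the predictions) unchanged.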

    def logits_to_preds(
        self,
        logits: torch.Tensor,
        span_mask: torch.Tensor,
        start_of_word: torch.Tensor,
        seq_2_start_t: torch.Tensor,
        max_answer_length: int = 1000,
        **kwargs,
    ):
        """
        Get the predicted index of start and end token of the answer. Note that the output is at token level
        and not word level. Note also that these logits correspond to the tokens of a sample
        (i.e. special tokens, question tokens, passage_tokens)
        """

        # Will be populated with the top-n predictions of each sample in the batch
        # shape = batch_size x ~top_n
        # Note that ~top_n = n if no_answer is within the top_n predictions
        # ~top_n = n+1 if no_answer is not within the top_n predictions
        all_top_n = []

        # logits is of shape [batch_size, max_seq_len, 2]. The final dimension corresponds to [start, end]
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        # Calculate a few useful variables
        batch_size = start_logits.size()[0]
        max_seq_len = start_logits.shape[1]  # target dim

        # get scores for all combinations of start and end logits => candidate answers
        start_matrix = start_logits.unsqueeze(2).expand(-1, -1, max_seq_len)
        end_matrix = end_logits.unsqueeze(1).expand(-1, max_seq_len, -1)
        start_end_matrix = start_matrix + end_matrix

        # disqualify answers where end < start
        # (set the lower triangular matrix to low value, excluding diagonal)
        indices = torch.tril_indices(max_seq_len, max_seq_len, offset=-1, device=start_end_matrix.device)
        start_end_matrix[:, indices[0][:], indices[1][:]] = -888

        # disqualify answers where answer span is greater than max_answer_length
        # (set the upper triangular matrix to low value, excluding diagonal)
        indices_long_span = torch.triu_indices(
            max_seq_len, max_seq_len, offset=max_answer_length, device=start_end_matrix.device
        )
        start_end_matrix[:, indices_long_span[0][:], indices_long_span[1][:]] = -777

        # disqualify answers where start=0, but end != 0
        start_end_matrix[:, 0, 1:] = -666

        # Turn 1d span_mask vectors into 2d span_mask along 2 different axes
        # span mask has:
        # 0 for every position that is never a valid start or end index (question tokens, mid and end special tokens, padding)
        # 1 everywhere else
        span_mask_start = span_mask.unsqueeze(2).expand(-1, -1, max_seq_len)
        span_mask_end = span_mask.unsqueeze(1).expand(-1, max_seq_len, -1)
        span_mask_2d = span_mask_start + span_mask_end
        # disqualify spans where either start or end is on an invalid token
        invalid_indices = torch.nonzero((span_mask_2d != 2), as_tuple=True)
        start_end_matrix[invalid_indices[0][:], invalid_indices[1][:], invalid_indices[2][:]] = -999

        # Sort the candidate answers by their score. Sorting happens on the flattened matrix.
        # flat_sorted_indices.shape: (batch_size, max_seq_len^2, 1)
        flat_scores = start_end_matrix.view(batch_size, -1)
        flat_sorted_indices_2d = flat_scores.sort(descending=True)[1]
        flat_sorted_indices = flat_sorted_indices_2d.unsqueeze(2)

        # The returned indices are then converted back to the original dimensionality of the matrix.
        # sorted_candidates.shape : (batch_size, max_seq_len^2, 2)
        start_indices = flat_sorted_indices // max_seq_len
        end_indices = flat_sorted_indices % max_seq_len
        sorted_candidates = torch.cat((start_indices, end_indices), dim=2)

        # Get the n_best candidate answers for each sample
        for sample_idx in range(batch_size):
            sample_top_n = self.get_top_candidates(
                sorted_candidates[sample_idx],
                start_end_matrix[sample_idx],
                sample_idx,
                start_matrix=start_matrix[sample_idx],
                end_matrix=end_matrix[sample_idx],
            )
            all_top_n.append(sample_top_n)

        return all_top_n
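
    # Worked example of the start/end matrix trick (comment only), with
    # max_seq_len = 3 and one sample. start_end_matrix[s, e] = start_logit[s] + end_logit[e]:
    #
    #     start_logits = [2.0, 5.0, 1.0]          # s = 0, 1, 2
    #     end_logits   = [1.0, 0.0, 4.0]          # e = 0, 1, 2
    #     start_end_matrix =
    #         [[3.0, 2.0, 6.0],
    #          [6.0, 5.0, 9.0],                   # best valid span: s=1, e=2, score 9.0
    #          [2.0, 1.0, 5.0]]
    #
    # The tril/triu/masking steps then overwrite invalid cells (e < s, overlong
    # spans, masked tokens, s=0 with e != 0) with large negative sentinels before sorting.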

    def get_top_candidates(self, sorted_candidates, start_end_matrix, sample_idx: int, start_matrix, end_matrix):
        """
        Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
        This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
        This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)
        """
        # Initialize some variables
        top_candidates: List[QACandidate] = []
        n_candidates = sorted_candidates.shape[0]
        start_idx_candidates = set()
        end_idx_candidates = set()

        start_matrix_softmax_start = torch.softmax(start_matrix[:, 0], dim=-1)
        end_matrix_softmax_end = torch.softmax(end_matrix[0, :], dim=-1)
        # Iterate over all candidates and break when we have all our n_best candidates
        for candidate_idx in range(n_candidates):
            if len(top_candidates) == self.n_best_per_sample:
                break

            # Retrieve candidate's indices
            start_idx = sorted_candidates[candidate_idx, 0].item()
            end_idx = sorted_candidates[candidate_idx, 1].item()
            # Ignore no_answer scores which will be extracted later in this method
            if start_idx == 0 and end_idx == 0:
                continue
            if self.duplicate_filtering > -1 and (start_idx in start_idx_candidates or end_idx in end_idx_candidates):
                continue
            score = start_end_matrix[start_idx, end_idx].item()
            confidence = (start_matrix_softmax_start[start_idx].item() + end_matrix_softmax_end[end_idx].item()) / 2
            top_candidates.append(
                QACandidate(
                    offset_answer_start=start_idx,
                    offset_answer_end=end_idx,
                    score=score,
                    answer_type="span",
                    offset_unit="token",
                    aggregation_level="passage",
                    passage_id=str(sample_idx),
                    confidence=confidence,
                )
            )
            if self.duplicate_filtering > -1:
                for i in range(0, self.duplicate_filtering + 1):
                    start_idx_candidates.add(start_idx + i)
                    start_idx_candidates.add(start_idx - i)
                    end_idx_candidates.add(end_idx + i)
                    end_idx_candidates.add(end_idx - i)

        no_answer_score = start_end_matrix[0, 0].item()
        no_answer_confidence = (start_matrix_softmax_start[0].item() + end_matrix_softmax_end[0].item()) / 2
        top_candidates.append(
            QACandidate(
                offset_answer_start=0,
                offset_answer_end=0,
                score=no_answer_score,
                answer_type="no_answer",
                offset_unit="token",
                aggregation_level="passage",
                passage_id=None,
                confidence=no_answer_confidence,
            )
        )

        return top_candidates

    def formatted_preds(
        self, preds: List[QACandidate], baskets: List[SampleBasket], logits: Optional[torch.Tensor] = None, **kwargs
    ):
        """
        Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level
        predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from
        ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have
        already converted the logits to predictions before calling formatted_preds.
        (see Inferencer._get_predictions_and_aggregate()).
        """
        # Unpack some useful variables
        # passage_start_t is the token index of the passage relative to the document (usually a multiple of doc_stride)
        # seq_2_start_t is the token index of the first token in passage relative to the input sequence (i.e. number of
        # special tokens and question tokens that come before the passage tokens)
        if logits is not None or preds is None:
            logger.error(
                "QuestionAnsweringHead.formatted_preds() expects preds as input and logits to be None "
                "but was passed something different"
            )
        samples = [s for b in baskets for s in b.samples]  # type: ignore
        ids = [s.id for s in samples]
        passage_start_t = [s.features[0]["passage_start_t"] for s in samples]  # type: ignore
        seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples]  # type: ignore

        # Aggregate passage level predictions to create document level predictions.
        # This method assumes that all passages of each document are contained in preds
        # i.e. that there are no incomplete documents. The output of this step
        # are prediction spans
        preds_d = self.aggregate_preds(preds, passage_start_t, ids, seq_2_start_t)

        # Separate top_preds list from the no_ans_gap float.
        top_preds, no_ans_gaps = zip(*preds_d)

        # Takes document level prediction spans and returns string predictions
        doc_preds = self.to_qa_preds(top_preds, no_ans_gaps, baskets)

        return doc_preds

    def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
        """
        Groups Span objects together in a QAPred object
        """
        ret = []

        # Iterate over each set of document level predictions
        for pred_d, no_ans_gap, basket in zip(top_preds, no_ans_gaps, baskets):

            # Unpack document offsets, clear text and id
            token_offsets = basket.raw["document_offsets"]
            pred_id = basket.id_external if basket.id_external else basket.id_internal

            # These options reflect the different input dicts that can be assigned to the basket
            # before any kind of normalization or preprocessing can happen
            question_names = ["question_text", "qas", "questions"]
            doc_names = ["document_text", "context", "text"]

            document_text = try_get(doc_names, basket.raw)
            question = self.get_question(question_names, basket.raw)
            ground_truth = self.get_ground_truth(basket)

            curr_doc_pred = QAPred(
                id=pred_id,
                prediction=pred_d,
                context=document_text,
                question=question,
                token_offsets=token_offsets,
                context_window_size=self.context_window_size,
                aggregation_level="document",
                ground_truth_answer=ground_truth,
                no_answer_gap=no_ans_gap,
            )

            ret.append(curr_doc_pred)
        return ret

    @staticmethod
    def get_ground_truth(basket: SampleBasket):
        if "answers" in basket.raw:
            return basket.raw["answers"]
        elif "annotations" in basket.raw:
            return basket.raw["annotations"]
        else:
            return None

    @staticmethod
    def get_question(question_names: List[str], raw_dict: Dict):
        # For NQ style dicts
        qa_name = None
        if "qas" in raw_dict:
            qa_name = "qas"
        elif "question" in raw_dict:
            qa_name = "question"
        if qa_name and isinstance(raw_dict[qa_name][0], dict):
            return raw_dict[qa_name][0]["question"]
        return try_get(question_names, raw_dict)

    def aggregate_preds(self, preds, passage_start_t, ids, seq_2_start_t=None, labels=None):
        """
        Aggregate passage level predictions to create document level predictions.
        This method assumes that all passages of each document are contained in preds
        i.e. that there are no incomplete documents. The output of this step
        are prediction spans. No answer is represented by a (-1, -1) span on the document level
        """
        # Initialize some variables
        n_samples = len(preds)
        all_basket_preds = {}
        all_basket_labels = {}

        # Iterate over the preds of each sample - remove final number which is the sample id and not needed for aggregation
        for sample_idx in range(n_samples):
            basket_id = ids[sample_idx]
            basket_id = basket_id.split("-")[:-1]
            basket_id = "-".join(basket_id)

            # curr_passage_start_t is the token offset of the current passage
            # It will always be a multiple of doc_stride
            curr_passage_start_t = passage_start_t[sample_idx]

            # This is to account for the fact that all model input sequences start with some special tokens
            # and also the question tokens before passage tokens.
            if seq_2_start_t:
                cur_seq_2_start_t = seq_2_start_t[sample_idx]
                curr_passage_start_t -= cur_seq_2_start_t

            # Converts the passage level predictions+labels to document level predictions+labels. Note
            # that on the passage level a no answer is (0,0) but at document level it is (-1,-1) since (0,0)
            # would refer to the first token of the document
            pred_d = self.pred_to_doc_idxs(preds[sample_idx], curr_passage_start_t)
            if labels:
                label_d = self.label_to_doc_idxs(labels[sample_idx], curr_passage_start_t)

            # Initialize the basket_id as a key in the all_basket_preds and all_basket_labels dictionaries
            if basket_id not in all_basket_preds:
                all_basket_preds[basket_id] = []
                all_basket_labels[basket_id] = []

            # Add predictions and labels to dictionary grouped by their basket_ids
            all_basket_preds[basket_id].append(pred_d)
            if labels:
                all_basket_labels[basket_id].append(label_d)

        # Pick n-best predictions and remove repeated labels
        all_basket_preds = {k: self.reduce_preds(v) for k, v in all_basket_preds.items()}
        if labels:
            all_basket_labels = {k: self.reduce_labels(v) for k, v in all_basket_labels.items()}

        # Return aggregated predictions in order as a list of lists
        keys = [k for k in all_basket_preds]
        aggregated_preds = [all_basket_preds[k] for k in keys]
        if labels:
            labels = [all_basket_labels[k] for k in keys]
            return aggregated_preds, labels
        else:
            return aggregated_preds
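
    # Aggregation sketch (comment only, illustrative values). Sample ids look
    # like "<basket>-<sample>", so the trailing number is stripped to group
    # passages by document:
    #
    #     ids             = ["42-0", "42-1"]        # two passages of document "42"
    #     passage_start_t = [0, 128]                # passage offsets in the document
    #     seq_2_start_t   = [15, 15]                # question + special tokens come first
    #
    # A prediction at input-sequence token 20 in the second passage lands at
    # document token 20 + 128 - 15 = 133; the per-passage candidate lists of
    # basket "42" are then merged by reduce_preds() into a single n_best list.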

    @staticmethod
    def reduce_labels(labels):
        """
        Removes repeat answers. Represents a no answer label as (-1,-1)
        """
        positive_answers = [(start, end) for x in labels for start, end in x if not (start == -1 and end == -1)]
        if not positive_answers:
            return [(-1, -1)]
        else:
            return list(set(positive_answers))

    def reduce_preds(self, preds):
        """
        This function contains the logic for choosing the best answers from each passage. In the end, it
        returns the n_best predictions on the document level.
        """
        # Initialize variables
        passage_no_answer = []
        passage_best_score = []
        passage_best_confidence = []
        no_answer_scores = []
        no_answer_confidences = []
        n_samples = len(preds)

        # Iterate over the top predictions for each sample
        for sample_idx, sample_preds in enumerate(preds):
            best_pred = sample_preds[0]
            best_pred_score = best_pred.score
            best_pred_confidence = best_pred.confidence
            no_answer_score, no_answer_confidence = self.get_no_answer_score_and_confidence(sample_preds)
            no_answer_score += self.no_ans_boost
            # TODO we might want to apply some kind of a no_ans_boost to no_answer_confidence too
            no_answer = no_answer_score > best_pred_score
            passage_no_answer.append(no_answer)
            no_answer_scores.append(no_answer_score)
            no_answer_confidences.append(no_answer_confidence)
            passage_best_score.append(best_pred_score)
            passage_best_confidence.append(best_pred_confidence)

        # Get all predictions in flattened list and sort by score
        pos_answers_flat = []
        for sample_idx, passage_preds in enumerate(preds):
            for qa_candidate in passage_preds:
                if not (qa_candidate.offset_answer_start == -1 and qa_candidate.offset_answer_end == -1):
                    pos_answers_flat.append(
                        QACandidate(
                            offset_answer_start=qa_candidate.offset_answer_start,
                            offset_answer_end=qa_candidate.offset_answer_end,
                            score=qa_candidate.score,
                            answer_type=qa_candidate.answer_type,
                            offset_unit="token",
                            aggregation_level="document",
                            passage_id=str(sample_idx),
                            n_passages_in_doc=n_samples,
                            confidence=qa_candidate.confidence,
                        )
                    )

        # TODO add switch for more variation in answers, e.g. if varied_ans then never return overlapping answers
        pos_answer_dedup = self.deduplicate(pos_answers_flat)

        # This is how much no_ans_boost needs to change to turn a no_answer to a positive answer (or vice versa)
        no_ans_gap = -min([nas - pbs for nas, pbs in zip(no_answer_scores, passage_best_score)])
        no_ans_gap_confidence = -min([nas - pbs for nas, pbs in zip(no_answer_confidences, passage_best_confidence)])

        # "no answer" scores and positive answers scores are difficult to compare, because
        # + a positive answer score is related to a specific text qa_candidate
        # - a "no answer" score is related to all input texts
        # Thus we compute the "no answer" score relative to the best possible answer and adjust it by
        # the most significant difference between scores.
        # Most significant difference: change top prediction from "no answer" to answer (or vice versa)
        best_overall_positive_score = max(x.score for x in pos_answer_dedup)
        best_overall_positive_confidence = max(x.confidence for x in pos_answer_dedup)
        no_answer_pred = QACandidate(
            offset_answer_start=-1,
            offset_answer_end=-1,
            score=best_overall_positive_score - no_ans_gap,
            answer_type="no_answer",
            offset_unit="token",
            aggregation_level="document",
            passage_id=None,
            n_passages_in_doc=n_samples,
            confidence=best_overall_positive_confidence - no_ans_gap_confidence,
        )

        # Add no answer to positive answers, sort the order and return the n_best
        n_preds = [no_answer_pred] + pos_answer_dedup
        n_preds_sorted = sorted(
            n_preds, key=lambda x: x.confidence if self.use_confidence_scores_for_ranking else x.score, reverse=True
        )
        n_preds_reduced = n_preds_sorted[: self.n_best]
        return n_preds_reduced, no_ans_gap
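
    # Worked no_ans_gap arithmetic (comment only, illustrative scores). With two
    # passages whose best positive scores are [7.0, 9.0] and whose (boosted)
    # no-answer scores are [8.0, 6.0]:
    #
    #     no_ans_gap = -min(8.0 - 7.0, 6.0 - 9.0) = -min(1.0, -3.0) = 3.0
    #
    # i.e. the no_answer logit would need a boost of 3.0 to overtake the best
    # positive answer, and the document-level no_answer candidate is scored as
    # best_overall_positive_score - no_ans_gap = 9.0 - 3.0 = 6.0.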

    @staticmethod
    def deduplicate(flat_pos_answers):
        # Remove duplicate spans that might be twice predicted in two different passages
        seen = {}
        for qa_answer in flat_pos_answers:
            if (qa_answer.offset_answer_start, qa_answer.offset_answer_end) not in seen:
                seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)] = qa_answer
            else:
                seen_score = seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)].score
                if qa_answer.score > seen_score:
                    seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)] = qa_answer
        return list(seen.values())

    @staticmethod
    def get_no_answer_score_and_confidence(preds):
        for qa_answer in preds:
            start = qa_answer.offset_answer_start
            end = qa_answer.offset_answer_end
            score = qa_answer.score
            confidence = qa_answer.confidence
            if start == -1 and end == -1:
                return score, confidence
        raise Exception("No no_answer candidate found in the predictions")

    @staticmethod
    def pred_to_doc_idxs(pred, passage_start_t):
        """
        Converts the passage level predictions to document level predictions. Note that on the doc level we
        don't have special tokens or question tokens. This means that a no answer
        cannot be represented by a (0,0) qa_answer but will instead be represented by (-1, -1)
        """
        new_pred = []
        for qa_answer in pred:
            start = qa_answer.offset_answer_start
            end = qa_answer.offset_answer_end
            if start == 0:
                start = -1
            else:
                start += passage_start_t
                if start < 0:
                    logger.error("Start token index < 0 (document level)")
            if end == 0:
                end = -1
            else:
                end += passage_start_t
                if end < 0:
                    logger.error("End token index < 0 (document level)")
            qa_answer.to_doc_level(start, end)
            new_pred.append(qa_answer)
        return new_pred

    @staticmethod
    def label_to_doc_idxs(label, passage_start_t):
        """
        Converts the passage level labels to document level labels. Note that on the doc level we
        don't have special tokens or question tokens. This means that a no answer
        cannot be represented by a (0,0) span but will instead be represented by (-1, -1)
        """
        new_label = []
        for start, end in label:
            # If there is a valid label
            if start > 0 or end > 0:
                new_label.append((start + passage_start_t, end + passage_start_t))
            # If the label is a no answer, we represent this as a (-1, -1) span
            # since there is no CLS token on the document level
            if start == 0 and end == 0:
                new_label.append((-1, -1))
        return new_label

    def prepare_labels(self, labels, start_of_word, **kwargs):
        return labels


class TextSimilarityHead(PredictionHead):
    """
    Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval.
    """

    def __init__(self, similarity_function: str = "dot_product", global_loss_buffer_size: int = 150000, **kwargs):
        """
        Init the TextSimilarityHead.

        :param similarity_function: Function to calculate similarity between queries and passage embeddings.
                                    Choose either "dot_product" (Default) or "cosine".
        :param global_loss_buffer_size: Buffer size for all_gather() in DDP.
                                        Increase if errors like "encoded data exceeds max_size ..." come up
        :param kwargs:
        """
        super(TextSimilarityHead, self).__init__()

        self.similarity_function = similarity_function
        self.loss_fct = NLLLoss(reduction="mean")
        self.task_name = "text_similarity"
        self.model_type = "text_similarity"
        self.ph_output_type = "per_sequence"
        self.global_loss_buffer_size = global_loss_buffer_size
        self.generate_config()

    @classmethod
    def dot_product_scores(cls, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> torch.Tensor:
        """
        Calculates dot product similarity scores for two 2-dimensional tensors

        :param query_vectors: tensor of query embeddings from BiAdaptive model
                              of dimension n1 x D,
                              where n1 is the number of queries/batch size and D is embedding size
        :param passage_vectors: tensor of context/passage embeddings from BiAdaptive model
                                of dimension n2 x D,
                                where n2 is (batch_size * num_positives) + (batch_size * num_hard_negatives)
                                and D is embedding size

        :return: dot_product: similarity score of each query with each context/passage (dimension: n1 x n2)
        """
        # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
        dot_product = torch.matmul(query_vectors, torch.transpose(passage_vectors, 0, 1))
        return dot_product

    @classmethod
    def cosine_scores(cls, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> torch.Tensor:
        """
        Calculates cosine similarity scores for two 2-dimensional tensors

        :param query_vectors: tensor of query embeddings from BiAdaptive model
                              of dimension n1 x D,
                              where n1 is the number of queries/batch size and D is embedding size
        :param passage_vectors: tensor of context/passage embeddings from BiAdaptive model
                                of dimension n2 x D,
                                where n2 is (batch_size * num_positives) + (batch_size * num_hard_negatives)
                                and D is embedding size

        :return: cosine similarity score of each query with each context/passage (dimension: n1 x n2)
        """
        # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
        cosine_similarities = []
        passages_per_batch = passage_vectors.shape[0]
        for query_vector in query_vectors:
            query_vector_repeated = query_vector.repeat(passages_per_batch, 1)
            current_cosine_similarities = nn.functional.cosine_similarity(query_vector_repeated, passage_vectors, dim=1)
            cosine_similarities.append(current_cosine_similarities)
        return torch.stack(cosine_similarities)
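
    # Shape sketch (comment only). With 2 queries and 6 passages (e.g. batch_size=2,
    # one positive and two hard negatives per query) and D=768:
    #
    #     q = torch.rand(2, 768)
    #     p = torch.rand(6, 768)
    #     TextSimilarityHead.dot_product_scores(q, p).shape   # torch.Size([2, 6])
    #     TextSimilarityHead.cosine_scores(q, p).shape        # torch.Size([2, 6])
    #
    # cosine_scores is equivalent to dot_product_scores on L2-normalized vectors;
    # computing it row by row keeps peak memory low.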

    def get_similarity_function(self):
        """
        Returns the type of similarity function used to compare queries and passages/contexts
        """
        if "dot_product" in self.similarity_function:
            return TextSimilarityHead.dot_product_scores
        elif "cosine" in self.similarity_function:
            return TextSimilarityHead.cosine_scores
        else:
            raise AttributeError(
                f"The similarity function can only be 'dot_product' or 'cosine', not '{self.similarity_function}'"
            )

    def forward(self, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Only packs the embeddings from both language models into a tuple. No further modification.
        The similarity calculation is handled later to enable distributed training (DDP)
        while keeping the support for in-batch negatives.
        (Gather all embeddings from nodes => then do similarity scores + loss)

        :param query_vectors: Tensor of query embeddings from BiAdaptive model
                              of dimension n1 x D,
                              where n1 is the number of queries/batch size and D is embedding size
        :param passage_vectors: Tensor of context/passage embeddings from BiAdaptive model
                                of dimension n2 x D,
                                where n2 is the number of passages and D is embedding size
        """
        return query_vectors, passage_vectors

    def _embeddings_to_scores(self, query_vectors: torch.Tensor, passage_vectors: torch.Tensor) -> torch.Tensor:
        """
        Calculates similarity scores between all given query_vectors and passage_vectors

        :param query_vectors: Tensor of queries encoded by the query encoder model
        :param passage_vectors: Tensor of passages encoded by the passage encoder model
        :return: Tensor of log softmax similarity scores of each query with each passage (dimension: n1 x n2)
        """
        sim_func = self.get_similarity_function()
        scores = sim_func(query_vectors, passage_vectors)

        if len(query_vectors.size()) > 1:
            q_num = query_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = nn.functional.log_softmax(scores, dim=1)
        return softmax_scores

    def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], label_ids, **kwargs):  # type: ignore
        """
        Computes the loss (Default: NLLLoss) by applying a similarity function (Default: dot product) to the input
        tuple of (query_vectors, passage_vectors) and afterwards applying the loss function on similarity scores.

        :param logits: Tuple of Tensors (query_embedding, passage_embedding) as returned from forward()

        :return: negative log likelihood loss from similarity scores
        """
        # Check if DDP is initialized
        try:
            if torch.distributed.is_available():
                rank = torch.distributed.get_rank()
            else:
                rank = -1
        except (AssertionError, RuntimeError):
            rank = -1

        # Prepare predicted scores
        query_vectors, passage_vectors = logits

        # Prepare labels
        positive_idx_per_question = torch.nonzero((label_ids.view(-1) == 1), as_tuple=False)

        # Gather global embeddings from all distributed nodes (DDP)
        if rank != -1:
            q_vector_to_send = torch.empty_like(query_vectors).cpu().copy_(query_vectors).detach_()
            p_vector_to_send = torch.empty_like(passage_vectors).cpu().copy_(passage_vectors).detach_()

            global_question_passage_vectors = all_gather_list(
                [q_vector_to_send, p_vector_to_send, positive_idx_per_question], max_size=self.global_loss_buffer_size
            )

            global_query_vectors = []
            global_passage_vectors = []
            global_positive_idx_per_question = []
            total_passages = 0
            for i, item in enumerate(global_question_passage_vectors):
                q_vector, p_vectors, positive_idx = item

                if i != rank:
                    global_query_vectors.append(q_vector.to(query_vectors.device))
                    global_passage_vectors.append(p_vectors.to(passage_vectors.device))
                    global_positive_idx_per_question.extend([v + total_passages for v in positive_idx])
                else:
                    global_query_vectors.append(query_vectors)
                    global_passage_vectors.append(passage_vectors)
                    global_positive_idx_per_question.extend([v + total_passages for v in positive_idx_per_question])
                total_passages += p_vectors.size(0)

            global_query_vectors = torch.cat(global_query_vectors, dim=0)  # type: ignore
            global_passage_vectors = torch.cat(global_passage_vectors, dim=0)  # type: ignore
            global_positive_idx_per_question = torch.LongTensor(global_positive_idx_per_question)  # type: ignore
        else:
            global_query_vectors = query_vectors  # type: ignore
            global_passage_vectors = passage_vectors  # type: ignore
            global_positive_idx_per_question = positive_idx_per_question  # type: ignore

        # Get similarity scores
        softmax_scores = self._embeddings_to_scores(global_query_vectors, global_passage_vectors)  # type: ignore
        targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device)  # type: ignore

        # Calculate loss
        loss = self.loss_fct(softmax_scores, targets)
        return loss
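
    # In-batch negatives sketch (comment only, illustrative label layout). With 2
    # queries and 3 passages per query (first one positive, rest hard negatives):
    #
    #     label_ids = torch.tensor([[1, 0, 0], [1, 0, 0]])
    #     positive_idx_per_question   # -> [[0], [3]], flat indices of the positives
    #
    # _embeddings_to_scores() yields a 2 x 6 log-softmax matrix over ALL passages,
    # so every other passage in the batch acts as a negative for each query, and
    # NLLLoss picks out the log-probability of columns 0 and 3 respectively.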

    def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs) -> torch.Tensor:  # type: ignore
        """
        Returns predicted ranks (similarity) of passages/contexts for each query

        :param logits: tensor of log softmax similarity scores of each query with each context/passage (dimension: n1 x n2)

        :return: predicted ranks of passages for each query
        """
        query_vectors, passage_vectors = logits
        softmax_scores = self._embeddings_to_scores(query_vectors, passage_vectors)
        _, sorted_scores = torch.sort(softmax_scores, dim=1, descending=True)
        return sorted_scores

    def prepare_labels(self, label_ids, **kwargs) -> torch.Tensor:  # type: ignore
        """
        Returns a tensor with passage labels (0: hard_negative, 1: positive) for each query

        :return: passage labels (0: hard_negative, 1: positive) for each query
        """
        labels = torch.zeros(label_ids.size(0), label_ids.numel())

        positive_indices = torch.nonzero(label_ids.view(-1) == 1, as_tuple=False)

        for i, indx in enumerate(positive_indices):
            labels[i, indx.item()] = 1
        return labels

    def formatted_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
        raise NotImplementedError("formatted_preds is not supported in TextSimilarityHead yet!")


def _is_json(x):
    if issubclass(type(x), Path):
        return True
    try:
        json.dumps(x)
        return True
    except Exception:
        return False