import logging
from abc import ABC
from typing import List, Any, Optional, Tuple, Union, Dict

logger = logging.getLogger(__name__)


class Pred(ABC):
    """
    Abstract base class for predictions of every task.
    """

    def __init__(self, id: str, prediction: List[Any], context: str):
        self.id = id
        self.prediction = prediction
        self.context = context

    def to_json(self):
        raise NotImplementedError


class QACandidate:
    """
    A single QA candidate answer.
    """

    def __init__(
        self,
        answer_type: str,
        score: float,
        offset_answer_start: int,
        offset_answer_end: int,
        offset_unit: str,
        aggregation_level: str,
        probability: Optional[float] = None,
        n_passages_in_doc: Optional[int] = None,
        passage_id: Optional[str] = None,
        confidence: Optional[float] = None,
    ):
        """
        :param answer_type: The category that this answer falls into, e.g. "no_answer", "yes", "no" or "span"
        :param score: The score representing the model's confidence in this answer
        :param offset_answer_start: The index of the start of the answer span (whether it is a char or token index is stated in self.offset_unit)
        :param offset_answer_end: The index of the end of the answer span (whether it is a char or token index is stated in self.offset_unit)
        :param offset_unit: States whether the offsets refer to character or token indices
        :param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation)
        :param probability: The probability the model assigns to the answer
        :param n_passages_in_doc: Number of passages that make up the document
        :param passage_id: The id of the passage which contains this candidate answer
        :param confidence: The (calibrated) confidence score representing the model's predicted accuracy of the answer span
        """
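
        # Illustrative instantiation (all values hypothetical), before the token
        # offsets have been resolved to a string answer:
        #   QACandidate(answer_type="span", score=7.3,
        #               offset_answer_start=2, offset_answer_end=2,
        #               offset_unit="token", aggregation_level="passage")
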
        # self.answer_type can be "no_answer", "yes", "no" or "span"
        self.answer_type = answer_type
        self.score = score
        self.probability = probability

        # If self.answer_type is "span", self.answer is a string answer (generated by self.span_to_string())
        # Otherwise, it is None
        self.answer = None  # type: Optional[str]
        self.offset_answer_start = offset_answer_start
        self.offset_answer_end = offset_answer_end

        # If self.answer_type is in ["yes", "no"] then self.answer_support is a text string
        # If self.answer is a string answer span or self.answer_type is "no_answer", answer_support is None
        self.answer_support = None  # type: Optional[str]
        self.offset_answer_support_start = None  # type: Optional[int]
        self.offset_answer_support_end = None  # type: Optional[int]

        # self.context_window is the window of the document or passage where the answer is found
        self.context_window = None  # type: Optional[str]
        self.offset_context_window_start = None  # type: Optional[int]
        self.offset_context_window_end = None  # type: Optional[int]

        # Offset unit is either "token" or "char"
        # Aggregation level is either "doc" or "passage"
        self.offset_unit = offset_unit
        self.aggregation_level = aggregation_level

        self.n_passages_in_doc = n_passages_in_doc
        self.passage_id = passage_id
        self.confidence = confidence

        # This attribute is used by Haystack to store sample metadata
        self.meta = None

    def set_context_window(self, context_window_size: int, clear_text: str):
        """Compute and store the context window (and its char offsets) around the answer within `clear_text`."""
        window_str, start_ch, end_ch = self._create_context_window(context_window_size, clear_text)
        self.context_window = window_str
        self.offset_context_window_start = start_ch
        self.offset_context_window_end = end_ch

    def set_answer_string(self, token_offsets: List[int], document_text: str):
        """Convert the token-level answer span into a string answer with char-level offsets."""
        pred_str, self.offset_answer_start, self.offset_answer_end = self._span_to_string(token_offsets, document_text)
        self.offset_unit = "char"
        self._add_answer(pred_str)

    def _add_answer(self, string: str):
        """
        Set the answer string. This method will check that the answer given is valid given the start
        and end indices that are stored in the object.
        """
        if string == "":
            self.answer = "no_answer"
            if self.offset_answer_start != 0 or self.offset_answer_end != 0:
                logger.error(
                    f"Both start and end offsets should be 0: \n"
                    f"{self.offset_answer_start}, {self.offset_answer_end} with a no_answer. "
                )
        else:
            self.answer = string
            if self.offset_answer_end - self.offset_answer_start <= 0:
                logger.error(
                    f"End offset comes before start offset: \n"
                    f"({self.offset_answer_start}, {self.offset_answer_end}) with a span answer. "
                )
            elif self.offset_answer_end <= 0:
                logger.error(
                    f"Invalid end offset: \n"
                    f"({self.offset_answer_start}, {self.offset_answer_end}) with a span answer. "
                )

    def _create_context_window(self, context_window_size: int, clear_text: str) -> Tuple[str, int, int]:
        """
        Extract from clear_text a window that contains the answer and (usually) some amount of text on either
        side of the answer. Useful for cases where the answer and its surrounding context need to be
        displayed in a UI. If context_window_size is smaller than the extracted answer, it is enlarged so
        that the window can contain the answer.

        :param context_window_size: The size of the context window to be generated. Note that the window size may be increased if the answer is longer.
        :param clear_text: The text from which the answer is extracted
        """
        if self.offset_answer_start == 0 and self.offset_answer_end == 0:
            return "", 0, 0
        else:
            # If the extracted answer is longer than the context_window_size,
            # we will increase the context_window_size
            len_ans = self.offset_answer_end - self.offset_answer_start
            context_window_size = max(context_window_size, len_ans + 1)

            len_text = len(clear_text)
            midpoint = int(len_ans / 2) + self.offset_answer_start
            half_window = int(context_window_size / 2)
            window_start_ch = midpoint - half_window
            window_end_ch = midpoint + half_window

            # if we have part of the context window overlapping the start or end of the passage,
            # we'll trim it and use the additional chars on the other side of the answer
            overhang_start = max(0, -window_start_ch)
            overhang_end = max(0, window_end_ch - len_text)
            window_start_ch -= overhang_end
            window_start_ch = max(0, window_start_ch)
            window_end_ch += overhang_start
            window_end_ch = min(len_text, window_end_ch)
            window_str = clear_text[window_start_ch:window_end_ch]
            return window_str, window_start_ch, window_end_ch

    def _span_to_string(self, token_offsets: List[int], clear_text: str) -> Tuple[str, int, int]:
        """
        Generates a string answer span using self.offset_answer_start and self.offset_answer_end. If the candidate
        is a no_answer, an empty string is returned.

        :param token_offsets: A list of ints which give the start character index of the corresponding token
        :param clear_text: The text from which the answer span is to be extracted
        :return: The string answer span, followed by the start and end character indices
        """
        if self.offset_unit != "token":
            logger.error(
                f"QACandidate needs to have self.offset_unit=token before calling _span_to_string() (id = {self.passage_id})"
            )

        start_t = self.offset_answer_start
        end_t = self.offset_answer_end

        # If it is a no_answer prediction
        if start_t == -1 and end_t == -1:
            return "", 0, 0

        n_tokens = len(token_offsets)

        # We do this to point to the beginning of the first token after the span instead of
        # the beginning of the last token in the span
        end_t += 1

        # Predictions sometimes land on the very final special token of the passage. But there are no
        # special tokens on the document level. We will just interpret this as a span that stretches
        # to the end of the document
        end_t = min(end_t, n_tokens)

        start_ch = int(token_offsets[start_t])
        # If end_t == n_tokens, the span stretches to the end of the document,
        # i.e. end_ch points at the END of the last token
        if end_t == n_tokens:
            end_ch = len(clear_text)
        else:
            end_ch = token_offsets[end_t]

        final_text = clear_text[start_ch:end_ch]

        # if final_text is more than whitespace we trim it, otherwise we return a no_answer.
        # final_text can be an empty string if start_t points to the very final token of the passage.
        # final_text can be whitespace if there is a whitespace token in the text, e.g.
        # if the original text contained multiple consecutive whitespaces
        if len(final_text.strip()) > 0:
            final_text = final_text.strip()
        else:
            return "", 0, 0

        end_ch = int(start_ch + len(final_text))

        return final_text, start_ch, end_ch

    def to_doc_level(self, start: int, end: int):
        """
        Populate the start and end indices with document level indices. Changes aggregation level to 'document'.
        """
        self.offset_answer_start = start
        self.offset_answer_end = end
        self.aggregation_level = "document"

    def to_list(self) -> List[Optional[Union[str, int, float]]]:
        return [self.answer, self.offset_answer_start, self.offset_answer_end, self.score, self.passage_id]


class QAPred(Pred):
    """
    A set of QA predictions for a passage or a document. The candidates are stored in QAPred.prediction, which is a
    list of QACandidate objects. Also contains all attributes needed to convert the object into JSON format and
    to create a context window for a UI.
    """

    def __init__(
        self,
        id: str,
        prediction: List[QACandidate],
        context: str,
        question: str,
        token_offsets: List[int],
        context_window_size: int,
        aggregation_level: str,
        no_answer_gap: float,
        ground_truth_answer: Optional[str] = None,
        answer_types: Optional[List[str]] = None,
    ):
"""
|
|
:param id: The id of the passage or document
|
|
:param prediction: A list of QACandidate objects for the given question and document
|
|
:param context: The text passage from which the answer can be extracted
|
|
:param question: The question being posed
|
|
:param token_offsets: A list of ints indicating the start char index of each token
|
|
:param context_window_size: The number of chars in the text window around the answer
|
|
:param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation)
|
|
:param no_answer_gap: How much the QuestionAnsweringHead.no_ans_boost needs to change to turn a no_answer to a positive answer
|
|
:param ground_truth_answer: Ground truth answers
|
|
:param answer_types: List of answer_types supported by this task e.g. ["span", "yes_no", "no_answer"]
|
|
"""
|
|
        super().__init__(id, prediction, context)
        self.question = question
        self.token_offsets = token_offsets
        self.context_window_size = context_window_size
        self.aggregation_level = aggregation_level
        self.answer_types = answer_types or []
        self.ground_truth_answer = ground_truth_answer
        self.no_answer_gap = no_answer_gap
        self.n_passages = self.prediction[0].n_passages_in_doc
        for qa_candidate in self.prediction:
            qa_candidate.set_answer_string(token_offsets, self.context)
            qa_candidate.set_context_window(self.context_window_size, self.context)
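
    # A minimal usage sketch (all values hypothetical). Constructing a QAPred
    # eagerly resolves each candidate's string answer and context window:
    #
    #   candidate = QACandidate(answer_type="span", score=7.3,
    #                           offset_answer_start=2, offset_answer_end=2,
    #                           offset_unit="token", aggregation_level="passage")
    #   pred = QAPred(id="doc-1", prediction=[candidate],
    #                 context="The quick brown fox", question="What is brown?",
    #                 token_offsets=[0, 4, 10, 16], context_window_size=50,
    #                 aggregation_level="document", no_answer_gap=0.0)
    #   pred.prediction[0].answer   # -> "brown"
    #   pred.to_json()              # -> {"task": "qa", "predictions": [...]}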

    def to_json(self, squad=False) -> Dict:
        """
        Converts the information stored in the object into a json format.

        :param squad: If True, no_answers are represented by the empty string instead of "no_answer"
        """
        answers = self._answers_to_json(self.id, squad)
        ret = {
            "task": "qa",
            "predictions": [
                {
                    "question": self.question,
                    "id": self.id,
                    "ground_truth": self.ground_truth_answer,
                    "answers": answers,
                    "no_ans_gap": self.no_answer_gap,  # Add no_ans_gap to current no_ans_boost for switching top prediction
                }
            ],
        }
        if squad:
            del ret["predictions"][0]["id"]  # type: ignore
            ret["predictions"][0]["question_id"] = self.id  # type: ignore
        return ret

    def _answers_to_json(self, ext_id, squad=False) -> List[Dict]:
        """
        Convert all answers into a json format.

        :param ext_id: ID of the question document pair
        :param squad: If True, no_answers are represented by the empty string instead of "no_answer"
        """
        ret = []

        # iterate over the top_n predictions of the one document
        for qa_candidate in self.prediction:
            if squad and qa_candidate.answer == "no_answer":
                answer_string = ""
            else:
                answer_string = qa_candidate.answer
            curr = {
                "score": qa_candidate.score,
                "probability": None,
                "answer": answer_string,
                "offset_answer_start": qa_candidate.offset_answer_start,
                "offset_answer_end": qa_candidate.offset_answer_end,
                "context": qa_candidate.context_window,
                "offset_context_start": qa_candidate.offset_context_window_start,
                "offset_context_end": qa_candidate.offset_context_window_end,
                "document_id": ext_id,
            }
            ret.append(curr)
        return ret

    def to_squad_eval(self) -> Dict:
        return self.to_json(squad=True)