# haystack/haystack/modeling/model/predictions.py

from typing import List, Any, Optional, Tuple, Union, Dict
import logging
from abc import ABC

logger = logging.getLogger(__name__)


class Pred(ABC):
"""
Abstract base class for predictions of every task
"""
def __init__(self, id: str, prediction: List[Any], context: str):
self.id = id
self.prediction = prediction
self.context = context
def to_json(self):
raise NotImplementedError


class QACandidate:
"""
A single QA candidate answer.
"""
def __init__(
self,
answer_type: str,
score: float,
offset_answer_start: int,
offset_answer_end: int,
offset_unit: str,
aggregation_level: str,
probability: Optional[float] = None,
n_passages_in_doc: Optional[int] = None,
passage_id: Optional[str] = None,
confidence: Optional[float] = None,
):
"""
:param answer_type: The category that this answer falls into e.g. "no_answer", "yes", "no" or "span"
        :param score: The score representing the model's confidence in this answer
        :param offset_answer_start: The index of the start of the answer span (whether it is char or tok is stated in self.offset_unit)
        :param offset_answer_end: The index of the end of the answer span (whether it is char or tok is stated in self.offset_unit)
:param offset_unit: States whether the offsets refer to character or token indices
:param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation)
:param probability: The probability the model assigns to the answer
:param n_passages_in_doc: Number of passages that make up the document
:param passage_id: The id of the passage which contains this candidate answer
        :param confidence: The (calibrated) confidence score representing the model's predicted accuracy for the start index of the answer span
"""
# self.answer_type can be "no_answer", "yes", "no" or "span"
self.answer_type = answer_type
self.score = score
self.probability = probability
# If self.answer_type is "span", self.answer is a string answer (generated by self.span_to_string())
# Otherwise, it is None
self.answer = None # type: Optional[str]
self.offset_answer_start = offset_answer_start
self.offset_answer_end = offset_answer_end
# If self.answer_type is in ["yes", "no"] then self.answer_support is a text string
# If self.answer is a string answer span or self.answer_type is "no_answer", answer_support is None
self.answer_support = None # type: Optional[str]
self.offset_answer_support_start = None # type: Optional[int]
self.offset_answer_support_end = None # type: Optional[int]
        # self.context_window is a text window around the answer, taken from the document or passage in which the answer was found
self.context_window = None # type: Optional[str]
self.offset_context_window_start = None # type: Optional[int]
self.offset_context_window_end = None # type: Optional[int]
# Offset unit is either "token" or "char"
        # Aggregation level is either "document" or "passage"
self.offset_unit = offset_unit
self.aggregation_level = aggregation_level
self.n_passages_in_doc = n_passages_in_doc
self.passage_id = passage_id
self.confidence = confidence
# This attribute is used by Haystack to store sample metadata
self.meta = None
def set_context_window(self, context_window_size: int, clear_text: str):
window_str, start_ch, end_ch = self._create_context_window(context_window_size, clear_text)
self.context_window = window_str
self.offset_context_window_start = start_ch
self.offset_context_window_end = end_ch
def set_answer_string(self, token_offsets: List[int], document_text: str):
pred_str, self.offset_answer_start, self.offset_answer_end = self._span_to_string(token_offsets, document_text)
self.offset_unit = "char"
self._add_answer(pred_str)
def _add_answer(self, string: str):
"""
        Set the answer string. This method checks that the given answer is consistent with the start
        and end indices stored in the object.
"""
if string == "":
self.answer = "no_answer"
if self.offset_answer_start != 0 or self.offset_answer_end != 0:
logger.error(
f"Both start and end offsets should be 0: \n"
f"{self.offset_answer_start}, {self.offset_answer_end} with a no_answer. "
)
else:
self.answer = string
if self.offset_answer_end - self.offset_answer_start <= 0:
logger.error(
f"End offset comes before start offset: \n"
f"({self.offset_answer_start}, {self.offset_answer_end}) with a span answer. "
)
elif self.offset_answer_end <= 0:
logger.error(
f"Invalid end offset: \n"
f"({self.offset_answer_start}, {self.offset_answer_end}) with a span answer. "
)
def _create_context_window(self, context_window_size: int, clear_text: str) -> Tuple[str, int, int]:
"""
        Extract from clear_text a window that contains the answer and (usually) some amount of text on either
        side of the answer. Useful for cases where the answer and its surrounding context need to be
        displayed in a UI. If context_window_size is smaller than the length of the extracted answer, the window is
        enlarged so that it can contain the answer.
:param context_window_size: The size of the context window to be generated. Note that the window size may be increased if the answer is longer.
:param clear_text: The text from which the answer is extracted
"""
if self.offset_answer_start == 0 and self.offset_answer_end == 0:
return "", 0, 0
else:
# If the extracted answer is longer than the context_window_size,
# we will increase the context_window_size
len_ans = self.offset_answer_end - self.offset_answer_start
context_window_size = max(context_window_size, len_ans + 1)
len_text = len(clear_text)
midpoint = int(len_ans / 2) + self.offset_answer_start
half_window = int(context_window_size / 2)
window_start_ch = midpoint - half_window
window_end_ch = midpoint + half_window
# if we have part of the context window overlapping the start or end of the passage,
# we'll trim it and use the additional chars on the other side of the answer
overhang_start = max(0, -window_start_ch)
overhang_end = max(0, window_end_ch - len_text)
window_start_ch -= overhang_end
window_start_ch = max(0, window_start_ch)
window_end_ch += overhang_start
window_end_ch = min(len_text, window_end_ch)
window_str = clear_text[window_start_ch:window_end_ch]
return window_str, window_start_ch, window_end_ch
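    # A worked example of the windowing arithmetic above, using hypothetical numbers:
    # with len(clear_text) == 100, an answer spanning chars [90, 98) and context_window_size == 20,
    # the initial window is [84, 104), which overhangs the end of the text by 4 chars.
    # Those 4 chars are shifted to the other side of the answer, giving the final window [80, 100),
    # still 20 chars long and fully containing the answer.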
def _span_to_string(self, token_offsets: List[int], clear_text: str) -> Tuple[str, int, int]:
"""
Generates a string answer span using self.offset_answer_start and self.offset_answer_end. If the candidate
is a no answer, an empty string is returned
:param token_offsets: A list of ints which give the start character index of the corresponding token
:param clear_text: The text from which the answer span is to be extracted
:return: The string answer span, followed by the start and end character indices
"""
if self.offset_unit != "token":
logger.error(
f"QACandidate needs to have self.offset_unit=token before calling _span_to_string() (id = {self.passage_id})"
)
start_t = self.offset_answer_start
end_t = self.offset_answer_end
# If it is a no_answer prediction
if start_t == -1 and end_t == -1:
return "", 0, 0
n_tokens = len(token_offsets)
# We do this to point to the beginning of the first token after the span instead of
# the beginning of the last token in the span
end_t += 1
# Predictions sometimes land on the very final special token of the passage. But there are no
# special tokens on the document level. We will just interpret this as a span that stretches
# to the end of the document
end_t = min(end_t, n_tokens)
start_ch = int(token_offsets[start_t])
# i.e. pointing at the END of the last token
if end_t == n_tokens:
end_ch = len(clear_text)
else:
end_ch = token_offsets[end_t]
final_text = clear_text[start_ch:end_ch]
        # If final_text contains more than whitespace, we strip it; otherwise we return a no_answer
# final_text can be an empty string if start_t points to the very final token of the passage
# final_text can be a whitespace if there is a whitespace token in the text, e.g.,
# if the original text contained multiple consecutive whitespaces
if len(final_text.strip()) > 0:
final_text = final_text.strip()
else:
return "", 0, 0
end_ch = int(start_ch + len(final_text))
return final_text, start_ch, end_ch
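    # A worked example with hypothetical inputs: for clear_text = "Berlin is the capital of Germany."
    # and token_offsets = [0, 7, 10, 14, 22, 25], a candidate with token offsets (3, 5) covers the last
    # three tokens; end_t becomes 6 == n_tokens, so end_ch falls back to len(clear_text) == 33 and the
    # method returns ("capital of Germany.", 14, 33).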
def to_doc_level(self, start: int, end: int):
"""
Populate the start and end indices with document level indices. Changes aggregation level to 'document'
"""
self.offset_answer_start = start
self.offset_answer_end = end
self.aggregation_level = "document"
def to_list(self) -> List[Optional[Union[str, int, float]]]:
return [self.answer, self.offset_answer_start, self.offset_answer_end, self.score, self.passage_id]
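

# --------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a QACandidate
# with token-level offsets might be resolved into a char-based answer string and context
# window. The sample text, token offsets and score are made up for demonstration only.
def _example_qacandidate_usage() -> QACandidate:
    text = "Berlin is the capital of Germany."
    token_offsets = [0, 7, 10, 14, 22, 25]  # start char index of each token (hypothetical)
    candidate = QACandidate(
        answer_type="span",
        score=0.9,
        offset_answer_start=0,  # token index of the first answer token ("Berlin")
        offset_answer_end=0,  # token index of the last answer token
        offset_unit="token",
        aggregation_level="passage",
        passage_id="0",
    )
    candidate.set_answer_string(token_offsets, text)  # now candidate.answer == "Berlin" and offsets are char-based
    candidate.set_context_window(20, text)  # candidate.context_window == text[0:20]
    return candidate

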
class QAPred(Pred):
"""
    A set of QA predictions for a passage or a document. The candidates are stored in QAPred.prediction, which is a
    list of QACandidate objects. It also contains all the attributes needed to convert the object into JSON format
    and to create a context window for a UI.
"""
def __init__(
self,
id: str,
prediction: List[QACandidate],
context: str,
question: str,
token_offsets: List[int],
context_window_size: int,
aggregation_level: str,
no_answer_gap: float,
        ground_truth_answer: Optional[str] = None,
        answer_types: Optional[List[str]] = None,
):
"""
:param id: The id of the passage or document
:param prediction: A list of QACandidate objects for the given question and document
:param context: The text passage from which the answer can be extracted
:param question: The question being posed
:param token_offsets: A list of ints indicating the start char index of each token
:param context_window_size: The number of chars in the text window around the answer
:param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation)
        :param no_answer_gap: How much the QuestionAnsweringHead.no_ans_boost needs to change to turn a no_answer into a positive answer
:param ground_truth_answer: Ground truth answers
:param answer_types: List of answer_types supported by this task e.g. ["span", "yes_no", "no_answer"]
"""
super().__init__(id, prediction, context)
self.question = question
self.token_offsets = token_offsets
self.context_window_size = context_window_size
self.aggregation_level = aggregation_level
        self.answer_types = answer_types if answer_types is not None else []
self.ground_truth_answer = ground_truth_answer
self.no_answer_gap = no_answer_gap
self.n_passages = self.prediction[0].n_passages_in_doc
for qa_candidate in self.prediction:
qa_candidate.set_answer_string(token_offsets, self.context)
qa_candidate.set_context_window(self.context_window_size, self.context)
    def to_json(self, squad: bool = False) -> Dict:
"""
Converts the information stored in the object into a json format.
:param squad: If True, no_answers are represented by the empty string instead of "no_answer"
"""
answers = self._answers_to_json(self.id, squad)
ret = {
"task": "qa",
"predictions": [
{
"question": self.question,
"id": self.id,
"ground_truth": self.ground_truth_answer,
"answers": answers,
"no_ans_gap": self.no_answer_gap, # Add no_ans_gap to current no_ans_boost for switching top prediction
}
],
}
if squad:
del ret["predictions"][0]["id"] # type: ignore
ret["predictions"][0]["question_id"] = self.id # type: ignore
return ret
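    # The returned dict has roughly the following shape (values here are illustrative only):
    # {
    #     "task": "qa",
    #     "predictions": [
    #         {
    #             "question": "What is the capital of Germany?",
    #             "id": "doc-0",
    #             "ground_truth": None,
    #             "answers": [{"answer": "Berlin", "score": 0.9, "offset_answer_start": 0, ...}],
    #             "no_ans_gap": 0.0,
    #         }
    #     ],
    # }
    # With squad=True, the "id" key is replaced by "question_id".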
    def _answers_to_json(self, ext_id: str, squad: bool = False) -> List[Dict]:
"""
Convert all answers into a json format
        :param ext_id: ID of the question-document pair
:param squad: If True, no_answers are represented by the empty string instead of "no_answer"
"""
ret = []
# iterate over the top_n predictions of the one document
for qa_candidate in self.prediction:
if squad and qa_candidate.answer == "no_answer":
answer_string = ""
else:
answer_string = qa_candidate.answer
curr = {
"score": qa_candidate.score,
"probability": None,
"answer": answer_string,
"offset_answer_start": qa_candidate.offset_answer_start,
"offset_answer_end": qa_candidate.offset_answer_end,
"context": qa_candidate.context_window,
"offset_context_start": qa_candidate.offset_context_window_start,
"offset_context_end": qa_candidate.offset_context_window_end,
"document_id": ext_id,
}
ret.append(curr)
return ret
def to_squad_eval(self) -> Dict:
return self.to_json(squad=True)
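

# --------------------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It wraps a single
# hypothetical QACandidate in a QAPred and serializes it; QAPred.__init__ resolves the
# candidate's answer string and context window from the passed token offsets.
def _example_qapred_usage() -> Dict:
    text = "Berlin is the capital of Germany."
    candidate = QACandidate(
        answer_type="span",
        score=0.9,
        offset_answer_start=0,
        offset_answer_end=0,
        offset_unit="token",
        aggregation_level="passage",
        n_passages_in_doc=1,
        passage_id="0",
    )
    pred = QAPred(
        id="doc-0",
        prediction=[candidate],
        context=text,
        question="What is the capital of Germany?",
        token_offsets=[0, 7, 10, 14, 22, 25],  # start char index of each token (hypothetical)
        context_window_size=100,
        aggregation_level="document",
        no_answer_gap=0.0,
    )
    return pred.to_json()  # or pred.to_squad_eval() for a SQuAD-style representation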