mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-26 17:36:34 +00:00
feat: Extractive QA answer deduplication (#6459)
* Add answer deduplication * Fix test * Handle None case * Release notes * Handle cases where documents or answer spans could be None * Adding checks for Nones and satisfying mypy * Add option to turn off deduplication * Adding unit tests * Refactored tests to use fixtures * Added overlap_threshold to run * Update test * Fixes related to the merge * Remove casting, use direct variable names * Move out if statement and add new test for it * Update if statement to match comment * Update how if statements work
This commit is contained in:
parent
c294b8ac8c
commit
dcf37c5173
@ -49,6 +49,7 @@ class ExtractiveReader:
|
|||||||
answers_per_seq: Optional[int] = None,
|
answers_per_seq: Optional[int] = None,
|
||||||
no_answer: bool = True,
|
no_answer: bool = True,
|
||||||
calibration_factor: float = 0.1,
|
calibration_factor: float = 0.1,
|
||||||
|
overlap_threshold: Optional[float] = 0.01,
|
||||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -75,6 +76,12 @@ class ExtractiveReader:
|
|||||||
:param no_answer: Whether to return an additional `no answer` with an empty text and a score representing the
|
:param no_answer: Whether to return an additional `no answer` with an empty text and a score representing the
|
||||||
probability that the other top_k answers are incorrect.
|
probability that the other top_k answers are incorrect.
|
||||||
:param calibration_factor: Factor used for calibrating probabilities.
|
:param calibration_factor: Factor used for calibrating probabilities.
|
||||||
|
:param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
|
||||||
|
supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
|
||||||
|
one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
|
||||||
|
However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
|
||||||
|
both of these answers could be kept if this variable is set to 0.25 or higher.
|
||||||
|
If None is provided then all answers are kept.
|
||||||
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
|
||||||
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
|
when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass,
|
||||||
see the model's documentation.
|
see the model's documentation.
|
||||||
@ -93,6 +100,7 @@ class ExtractiveReader:
|
|||||||
self.no_answer = no_answer
|
self.no_answer = no_answer
|
||||||
self.calibration_factor = calibration_factor
|
self.calibration_factor = calibration_factor
|
||||||
self.model_kwargs = model_kwargs or {}
|
self.model_kwargs = model_kwargs or {}
|
||||||
|
self.overlap_threshold = overlap_threshold
|
||||||
|
|
||||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
@ -258,6 +266,7 @@ class ExtractiveReader:
|
|||||||
query_ids: List[int],
|
query_ids: List[int],
|
||||||
document_ids: List[int],
|
document_ids: List[int],
|
||||||
no_answer: bool,
|
no_answer: bool,
|
||||||
|
overlap_threshold: Optional[float],
|
||||||
) -> List[List[ExtractedAnswer]]:
|
) -> List[List[ExtractedAnswer]]:
|
||||||
"""
|
"""
|
||||||
Reconstructs the nested structure that existed before flattening. Also computes a no answer score.
|
Reconstructs the nested structure that existed before flattening. Also computes a no answer score.
|
||||||
@ -290,7 +299,8 @@ class ExtractiveReader:
|
|||||||
answer.query = queries[query_id]
|
answer.query = queries[query_id]
|
||||||
current_answers.append(answer)
|
current_answers.append(answer)
|
||||||
i += 1
|
i += 1
|
||||||
current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
|
current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
|
||||||
|
current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
|
||||||
current_answers = current_answers[:top_k]
|
current_answers = current_answers[:top_k]
|
||||||
if no_answer:
|
if no_answer:
|
||||||
no_answer_score = math.prod(1 - answer.score for answer in current_answers)
|
no_answer_score = math.prod(1 - answer.score for answer in current_answers)
|
||||||
@ -298,13 +308,120 @@ class ExtractiveReader:
|
|||||||
data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
|
data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
|
||||||
)
|
)
|
||||||
current_answers.append(answer_)
|
current_answers.append(answer_)
|
||||||
current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
|
current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
|
||||||
if score_threshold is not None:
|
if score_threshold is not None:
|
||||||
current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
|
current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
|
||||||
nested_answers.append(current_answers)
|
nested_answers.append(current_answers)
|
||||||
|
|
||||||
return nested_answers
|
return nested_answers
|
||||||
|
|
||||||
|
def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
|
||||||
|
"""
|
||||||
|
Calculates the amount of overlap (in number of characters) between two answer offsets.
|
||||||
|
|
||||||
|
Stack overflow post explaining how to calculate overlap between two ranges:
|
||||||
|
https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964
|
||||||
|
"""
|
||||||
|
# Check for overlap: (StartA <= EndB) and (StartB <= EndA)
|
||||||
|
if answer1_start <= answer2_end and answer2_start <= answer1_end:
|
||||||
|
return min(
|
||||||
|
answer1_end - answer1_start,
|
||||||
|
answer1_end - answer2_start,
|
||||||
|
answer2_end - answer1_start,
|
||||||
|
answer2_end - answer2_start,
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _should_keep(
|
||||||
|
self, candidate_answer: ExtractedAnswer, current_answers: List[ExtractedAnswer], overlap_threshold: float
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the answer should be kept based on how much it overlaps with previous answers.
|
||||||
|
|
||||||
|
NOTE: We might want to avoid throwing away answers that only have a few character (or word) overlap:
|
||||||
|
- E.g. The answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
|
||||||
|
might both want to be kept.
|
||||||
|
|
||||||
|
:param candidate_answer: Candidate answer that will be checked if it should be kept.
|
||||||
|
:param current_answers: Current list of answers that will be kept.
|
||||||
|
:param overlap_threshold: If the overlap between two answers is greater than this threshold then return False.
|
||||||
|
"""
|
||||||
|
keep = True
|
||||||
|
|
||||||
|
# If the candidate answer doesn't have a document keep it
|
||||||
|
if not candidate_answer.document:
|
||||||
|
return keep
|
||||||
|
|
||||||
|
for ans in current_answers:
|
||||||
|
# If an answer in current_answers doesn't have a document skip the comparison
|
||||||
|
if not ans.document:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If offset is missing then keep both
|
||||||
|
if ans.document_offset is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If offset is missing then keep both
|
||||||
|
if candidate_answer.document_offset is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If the answers come from different documents then keep both
|
||||||
|
if candidate_answer.document.id != ans.document.id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
overlap_len = self._calculate_overlap(
|
||||||
|
answer1_start=ans.document_offset.start,
|
||||||
|
answer1_end=ans.document_offset.end,
|
||||||
|
answer2_start=candidate_answer.document_offset.start,
|
||||||
|
answer2_end=candidate_answer.document_offset.end,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If overlap is 0 then keep
|
||||||
|
if overlap_len == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start)
|
||||||
|
overlap_frac_answer2 = overlap_len / (
|
||||||
|
candidate_answer.document_offset.end - candidate_answer.document_offset.start
|
||||||
|
)
|
||||||
|
|
||||||
|
if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold:
|
||||||
|
keep = False
|
||||||
|
break
|
||||||
|
|
||||||
|
return keep
|
||||||
|
|
||||||
|
def deduplicate_by_overlap(
    self, answers: List[ExtractedAnswer], overlap_threshold: Optional[float]
) -> List[ExtractedAnswer]:
    """
    This de-duplicates overlapping Extractive Answers from the same document based on how much the spans of the
    answers overlap.

    Answers are processed in the order given and an answer is dropped if it overlaps too much with an answer
    that was already kept, so earlier answers take precedence over later ones.

    :param answers: List of answers to be deduplicated. May be empty.
    :param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
        supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
        one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
        However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
        both of these answers could be kept if this variable is set to 0.25 or higher.
        If None is provided then all answers are kept.
    :return: The answers that survive deduplication, in their original relative order. When
        `overlap_threshold` is None the input list itself is returned unchanged.
    """
    if overlap_threshold is None:
        return answers

    # Nothing to deduplicate; without this guard the answers[0] lookup below raises IndexError
    if not answers:
        return []

    # The first answer has nothing to be compared against, so it is always kept
    deduplicated_answers = [answers[0]]

    # Every remaining answer is checked against the answers kept so far
    for ans in answers[1:]:
        keep = self._should_keep(
            candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold
        )
        if keep:
            deduplicated_answers.append(ans)

    return deduplicated_answers
|
||||||
|
|
||||||
@component.output_types(answers=List[ExtractedAnswer])
|
@component.output_types(answers=List[ExtractedAnswer])
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
@ -317,6 +434,7 @@ class ExtractiveReader:
|
|||||||
max_batch_size: Optional[int] = None,
|
max_batch_size: Optional[int] = None,
|
||||||
answers_per_seq: Optional[int] = None,
|
answers_per_seq: Optional[int] = None,
|
||||||
no_answer: Optional[bool] = None,
|
no_answer: Optional[bool] = None,
|
||||||
|
overlap_threshold: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Locates and extracts answers from the given Documents using the given query.
|
Locates and extracts answers from the given Documents using the given query.
|
||||||
@ -325,8 +443,6 @@ class ExtractiveReader:
|
|||||||
:param documents: List of Documents in which you want to search for an answer to the query.
|
:param documents: List of Documents in which you want to search for an answer to the query.
|
||||||
:param top_k: The maximum number of answers to return.
|
:param top_k: The maximum number of answers to return.
|
||||||
An additional answer is returned if no_answer is set to True (default).
|
An additional answer is returned if no_answer is set to True (default).
|
||||||
:param score_threshold:
|
|
||||||
:return: List of ExtractedAnswers sorted by (desc.) answer score.
|
|
||||||
:param score_threshold: Returns only answers with the score above this threshold.
|
:param score_threshold: Returns only answers with the score above this threshold.
|
||||||
:param max_seq_length: Maximum number of tokens.
|
:param max_seq_length: Maximum number of tokens.
|
||||||
If a sequence exceeds it, the sequence is split.
|
If a sequence exceeds it, the sequence is split.
|
||||||
@ -336,7 +452,17 @@ class ExtractiveReader:
|
|||||||
:param max_batch_size: Maximum number of samples that are fed through the model at the same time.
|
:param max_batch_size: Maximum number of samples that are fed through the model at the same time.
|
||||||
:param answers_per_seq: Number of answer candidates to consider per sequence.
|
:param answers_per_seq: Number of answer candidates to consider per sequence.
|
||||||
This is relevant when a Document was split into multiple sequences because of max_seq_length.
|
This is relevant when a Document was split into multiple sequences because of max_seq_length.
|
||||||
|
Default: 20
|
||||||
:param no_answer: Whether to return no answer scores.
|
:param no_answer: Whether to return no answer scores.
|
||||||
|
Default: True
|
||||||
|
:param overlap_threshold: If set this will remove duplicate answers if they have an overlap larger than the
|
||||||
|
supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
|
||||||
|
one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
|
||||||
|
However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
|
||||||
|
both of these answers could be kept if this variable is set to 0.25 or higher.
|
||||||
|
If None is provided then all answers are kept.
|
||||||
|
Default: 0.01
|
||||||
|
:return: List of ExtractedAnswers sorted by (desc.) answer score.
|
||||||
"""
|
"""
|
||||||
queries = [query] # Temporary solution until we have decided what batching should look like in v2
|
queries = [query] # Temporary solution until we have decided what batching should look like in v2
|
||||||
nested_documents = [documents]
|
nested_documents = [documents]
|
||||||
@ -348,8 +474,9 @@ class ExtractiveReader:
|
|||||||
max_seq_length = max_seq_length or self.max_seq_length
|
max_seq_length = max_seq_length or self.max_seq_length
|
||||||
stride = stride or self.stride
|
stride = stride or self.stride
|
||||||
max_batch_size = max_batch_size or self.max_batch_size
|
max_batch_size = max_batch_size or self.max_batch_size
|
||||||
answers_per_seq = answers_per_seq or self.answers_per_seq or top_k or 20
|
answers_per_seq = answers_per_seq or self.answers_per_seq or 20
|
||||||
no_answer = no_answer if no_answer is not None else self.no_answer
|
no_answer = no_answer if no_answer is not None else self.no_answer
|
||||||
|
overlap_threshold = overlap_threshold or self.overlap_threshold
|
||||||
|
|
||||||
flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
|
flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
|
||||||
input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
|
input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
|
||||||
@ -385,17 +512,18 @@ class ExtractiveReader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
answers = self._nest_answers(
|
answers = self._nest_answers(
|
||||||
start,
|
start=start,
|
||||||
end,
|
end=end,
|
||||||
probabilities,
|
probabilities=probabilities,
|
||||||
flattened_documents,
|
flattened_documents=flattened_documents,
|
||||||
queries,
|
queries=queries,
|
||||||
answers_per_seq,
|
answers_per_seq=answers_per_seq,
|
||||||
top_k,
|
top_k=top_k,
|
||||||
score_threshold,
|
score_threshold=score_threshold,
|
||||||
query_ids,
|
query_ids=query_ids,
|
||||||
document_ids,
|
document_ids=document_ids,
|
||||||
no_answer,
|
no_answer=no_answer,
|
||||||
|
overlap_threshold=overlap_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"answers": answers[0]} # same temporary batching fix as above
|
return {"answers": answers[0]} # same temporary batching fix as above
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
---
|
||||||
|
features:
|
||||||
|
- |
|
||||||
|
Introduces answer deduplication on the Document level based on an overlap threshold.
|
@ -7,7 +7,7 @@ import torch
|
|||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
|
||||||
from haystack.components.readers import ExtractiveReader
|
from haystack.components.readers import ExtractiveReader
|
||||||
from haystack import Document
|
from haystack import Document, ExtractedAnswer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -233,8 +233,19 @@ def test_nest_answers(mock_reader: ExtractiveReader):
|
|||||||
probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25
|
probabilities = torch.arange(5).unsqueeze(0) / 5 + torch.arange(6).unsqueeze(-1) / 25
|
||||||
query_ids = [0] * 3 + [1] * 3
|
query_ids = [0] * 3 + [1] * 3
|
||||||
document_ids = list(range(3)) * 2
|
document_ids = list(range(3)) * 2
|
||||||
nested_answers = mock_reader._nest_answers(
|
nested_answers = mock_reader._nest_answers( # type: ignore
|
||||||
start, end, probabilities, example_documents[0], example_queries, 5, 3, None, query_ids, document_ids, True # type: ignore
|
start=start,
|
||||||
|
end=end,
|
||||||
|
probabilities=probabilities,
|
||||||
|
flattened_documents=example_documents[0],
|
||||||
|
queries=example_queries,
|
||||||
|
answers_per_seq=5,
|
||||||
|
top_k=3,
|
||||||
|
score_threshold=None,
|
||||||
|
query_ids=query_ids,
|
||||||
|
document_ids=document_ids,
|
||||||
|
no_answer=True,
|
||||||
|
overlap_threshold=None,
|
||||||
)
|
)
|
||||||
expected_no_answers = [0.2 * 0.16 * 0.12, 0]
|
expected_no_answers = [0.2 * 0.16 * 0.12, 0]
|
||||||
for query, answers, expected_no_answer, probabilities in zip(
|
for query, answers, expected_no_answer, probabilities in zip(
|
||||||
@ -261,6 +272,196 @@ def test_warm_up_use_hf_token(mocked_automodel, mocked_autotokenizer):
|
|||||||
mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
|
mocked_autotokenizer.assert_called_once_with("deepset/roberta-base-squad2", token="fake-token")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeduplication:
    """Unit tests for the overlap-based answer deduplication helpers of ExtractiveReader."""

    @pytest.fixture
    def doc1(self):
        """Document that the candidate answer and most compared answers are taken from."""
        return Document(content="I want to go to the river in Maine.")

    @pytest.fixture
    def doc2(self):
        """Second, unrelated document used to check that answers are only compared within a document."""
        return Document(content="I want to go skiing in Colorado.")

    @staticmethod
    def _span_of(text, doc):
        """Builds the span covering the first occurrence of `text` in the content of `doc`."""
        start = doc.content.find(text)
        return ExtractedAnswer.Span(start, start + len(text))

    @staticmethod
    def _make_answer(text, document, document_offset):
        """Builds an ExtractedAnswer with the fixed query, score and meta shared by all tests here."""
        return ExtractedAnswer(
            query="test", data=text, document=document, document_offset=document_offset, score=0.1, meta={}
        )

    @pytest.fixture
    def candidate_answer(self, doc1):
        """Answer "the river" located in doc1; used as the answer under test."""
        text = "the river"
        return self._make_answer(text, doc1, self._span_of(text, doc1))

    def test_calculate_overlap(self, mock_reader: ExtractiveReader, doc1: Document):
        # "the river" and "river in Maine" share the five characters of "river"
        first = self._span_of("the river", doc1)
        second = self._span_of("river in Maine", doc1)
        overlap_in_characters = mock_reader._calculate_overlap(
            answer1_start=first.start, answer1_end=first.end, answer2_start=second.start, answer2_end=second.end
        )
        assert overlap_in_characters == 5

    def test_should_keep_false(
        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
    ):
        # An accepted answer from the same document overlaps the candidate, so the candidate is rejected
        overlapping_text = "river in Maine"
        other_doc_text = "skiing in Colorado"
        accepted = [
            self._make_answer(overlapping_text, doc1, self._span_of(overlapping_text, doc1)),
            self._make_answer(other_doc_text, doc2, self._span_of(other_doc_text, doc2)),
        ]
        keep = mock_reader._should_keep(
            candidate_answer=candidate_answer, current_answers=accepted, overlap_threshold=0.01
        )
        assert keep is False

    def test_should_keep_true(
        self, mock_reader: ExtractiveReader, doc1: Document, doc2: Document, candidate_answer: ExtractedAnswer
    ):
        # Neither accepted answer shares characters with the candidate within the same document
        disjoint_text = "Maine"
        other_doc_text = "skiing in Colorado"
        accepted = [
            self._make_answer(disjoint_text, doc1, self._span_of(disjoint_text, doc1)),
            self._make_answer(other_doc_text, doc2, self._span_of(other_doc_text, doc2)),
        ]
        keep = mock_reader._should_keep(
            candidate_answer=candidate_answer, current_answers=accepted, overlap_threshold=0.01
        )
        assert keep is True

    def test_should_keep_missing_document_current_answer(
        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
    ):
        # The accepted answer has a span but no document, so it is skipped during the comparison
        text = "river in Maine"
        accepted = [self._make_answer(text, None, self._span_of(text, doc1))]
        keep = mock_reader._should_keep(
            candidate_answer=candidate_answer, current_answers=accepted, overlap_threshold=0.01
        )
        assert keep is True

    def test_should_keep_missing_document_candidate_answer(
        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
    ):
        # The candidate itself has no document, so it is kept even though the spans are identical
        text = "river in Maine"
        span = self._span_of(text, doc1)
        keep = mock_reader._should_keep(
            candidate_answer=self._make_answer(text, None, span),
            current_answers=[self._make_answer(text, doc1, self._span_of(text, doc1))],
            overlap_threshold=0.01,
        )
        assert keep is True

    def test_should_keep_missing_span(
        self, mock_reader: ExtractiveReader, doc1: Document, candidate_answer: ExtractedAnswer
    ):
        # The accepted answer has no span, so no overlap can be computed and the candidate is kept
        accepted = [self._make_answer("river in Maine", doc1, None)]
        keep = mock_reader._should_keep(
            candidate_answer=candidate_answer, current_answers=accepted, overlap_threshold=0.01
        )
        assert keep is True

    def test_deduplicate_by_overlap_none_overlap(
        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer
    ):
        # A threshold of None switches deduplication off, so the exact duplicate survives
        duplicates = [candidate_answer, candidate_answer]
        result = mock_reader.deduplicate_by_overlap(answers=duplicates, overlap_threshold=None)
        assert len(result) == 2

    def test_deduplicate_by_overlap(
        self, mock_reader: ExtractiveReader, candidate_answer: ExtractedAnswer, doc1: Document
    ):
        # The repeated candidate is removed while the non-overlapping "Maine" answer is kept
        text = "Maine"
        extracted_answer2 = self._make_answer(text, doc1, self._span_of(text, doc1))
        result = mock_reader.deduplicate_by_overlap(
            answers=[candidate_answer, candidate_answer, extracted_answer2], overlap_threshold=0.01
        )
        assert len(result) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_t5():
|
def test_t5():
|
||||||
reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
|
reader = ExtractiveReader("TARUNBHATT/flan-t5-small-finetuned-squad")
|
||||||
@ -274,6 +475,7 @@ def test_t5():
|
|||||||
assert answers[1].score == pytest.approx(0.7703777551651001)
|
assert answers[1].score == pytest.approx(0.7703777551651001)
|
||||||
assert answers[2].data is None
|
assert answers[2].data is None
|
||||||
assert answers[2].score == pytest.approx(0.051331606147570596)
|
assert answers[2].score == pytest.approx(0.051331606147570596)
|
||||||
|
assert len(answers) == 3
|
||||||
# Uncomment assertions below when batching is reintroduced
|
# Uncomment assertions below when batching is reintroduced
|
||||||
# assert answers[0][2].score == pytest.approx(0.051331606147570596)
|
# assert answers[0][2].score == pytest.approx(0.051331606147570596)
|
||||||
# assert answers[1][0].data == "Jerry"
|
# assert answers[1][0].data == "Jerry"
|
||||||
@ -297,6 +499,7 @@ def test_roberta():
|
|||||||
assert answers[1].score == pytest.approx(0.857952892780304)
|
assert answers[1].score == pytest.approx(0.857952892780304)
|
||||||
assert answers[2].data is None
|
assert answers[2].data is None
|
||||||
assert answers[2].score == pytest.approx(0.019673851661650588)
|
assert answers[2].score == pytest.approx(0.019673851661650588)
|
||||||
|
assert len(answers) == 3
|
||||||
# uncomment assertions below when there is batching in v2
|
# uncomment assertions below when there is batching in v2
|
||||||
# assert answers[0][0].data == "Olaf Scholz"
|
# assert answers[0][0].data == "Olaf Scholz"
|
||||||
# assert answers[0][0].score == pytest.approx(0.8614975214004517)
|
# assert answers[0][0].score == pytest.approx(0.8614975214004517)
|
||||||
@ -314,7 +517,7 @@ def test_roberta():
|
|||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_matches_hf_pipeline():
|
def test_matches_hf_pipeline():
|
||||||
reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu")
|
reader = ExtractiveReader("deepset/tinyroberta-squad2", device="cpu", overlap_threshold=None)
|
||||||
reader.warm_up()
|
reader.warm_up()
|
||||||
answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[
|
answers = reader.run(example_queries[0], [[example_documents[0][0]]][0], top_k=20, no_answer=False)[
|
||||||
"answers"
|
"answers"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user