mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 20:17:14 +00:00
Implement Context Matching (#2293)
* first context_matching impl * Update Documentation & Code Style * sort matches * fix matching bugs * Update Documentation & Code Style * add match_contexts * min_words added * Update Documentation & Code Style * rename matching.py to context_matching.py * fix mypy * added tests and heuristic for one-sided overlaps * Update Documentation & Code Style * add another noise test * Update Documentation & Code Style * improve boosting split overlaps * add non parallel versions of match_context and match_contexts * Update Documentation & Code Style * fix pylint finding * add tests for match_context and match_contexts * fix typo Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
46fa166c36
commit
e13df4b22b
@ -18,3 +18,4 @@ from haystack.utils.export_utils import (
|
||||
convert_labels_to_squad,
|
||||
)
|
||||
from haystack.utils.squad_data import SquadData
|
||||
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
|
||||
|
||||
213
haystack/utils/context_matching.py
Normal file
213
haystack/utils/context_matching.py
Normal file
@ -0,0 +1,213 @@
|
||||
from collections import namedtuple
|
||||
import multiprocessing
|
||||
from typing import Generator, Iterable, Optional, Tuple, List, Union
|
||||
import re
|
||||
from rapidfuzz import fuzz
|
||||
from multiprocessing import Pool
|
||||
from tqdm import tqdm
|
||||
from itertools import groupby
|
||||
|
||||
|
||||
_CandidateScore = namedtuple("_CandidateScore", ["context_id", "candidate_id", "score"])
|
||||
|
||||
|
||||
def _score_candidate(args: Tuple[Union[str, Tuple[object, str]], Tuple[object, str], int, bool]):
|
||||
context, candidate, min_length, boost_split_overlaps = args
|
||||
candidate_id, candidate_text = candidate
|
||||
context_id, context_text = (None, context) if isinstance(context, str) else context
|
||||
score = calculate_context_similarity(
|
||||
context=context_text, candidate=candidate_text, min_length=min_length, boost_split_overlaps=boost_split_overlaps
|
||||
)
|
||||
return _CandidateScore(context_id=context_id, candidate_id=candidate_id, score=score)
|
||||
|
||||
|
||||
def normalize_white_space_and_case(str: str) -> str:
|
||||
return re.sub(r"\s+", " ", str).lower().strip()
|
||||
|
||||
|
||||
def _no_processor(str: str) -> str:
|
||||
return str
|
||||
|
||||
|
||||
def calculate_context_similarity(
|
||||
context: str, candidate: str, min_length: int = 100, boost_split_overlaps: bool = True
|
||||
) -> float:
|
||||
"""
|
||||
Calculates the text similarity score of context and candidate.
|
||||
The score's value ranges between 0.0 and 100.0.
|
||||
|
||||
:param context: The context to match.
|
||||
:param candidate: The candidate to match the context.
|
||||
:param min_length: The minimum string length context and candidate need to have in order to be scored.
|
||||
Returns 0.0 otherwise.
|
||||
:param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
|
||||
If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
|
||||
we cut the context on the same side, recalculate the score and take the mean of both.
|
||||
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
|
||||
"""
|
||||
# we need to handle short contexts/contents (e.g single word)
|
||||
# as they produce high scores by matching if the chars of the word are contained in the other one
|
||||
# this has to be done after normalizing
|
||||
context = normalize_white_space_and_case(context)
|
||||
candidate = normalize_white_space_and_case(candidate)
|
||||
context_len = len(context)
|
||||
candidate_len = len(candidate)
|
||||
if candidate_len < min_length or context_len < min_length:
|
||||
return 0.0
|
||||
|
||||
if context_len < candidate_len:
|
||||
shorter = context
|
||||
longer = candidate
|
||||
shorter_len = context_len
|
||||
longer_len = candidate_len
|
||||
else:
|
||||
shorter = candidate
|
||||
longer = context
|
||||
shorter_len = candidate_len
|
||||
longer_len = context_len
|
||||
|
||||
score_alignment = fuzz.partial_ratio_alignment(shorter, longer, processor=_no_processor)
|
||||
score = score_alignment.score
|
||||
|
||||
# Special handling for split overlaps (e.g. [AB] <-> [BC]):
|
||||
# If we detect that the score is near a half match and the best fitting part of longer is at its boundaries
|
||||
# we cut the shorter on the same side, recalculate the score and take the mean of both.
|
||||
# Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total
|
||||
if boost_split_overlaps and 40 <= score < 65:
|
||||
cut_shorter_left = score_alignment.dest_start == 0
|
||||
cut_shorter_right = score_alignment.dest_end == longer_len
|
||||
cut_len = shorter_len // 2
|
||||
|
||||
if cut_shorter_left:
|
||||
cut_score = fuzz.partial_ratio(shorter[cut_len:], longer, processor=_no_processor)
|
||||
if cut_score > score:
|
||||
score = (score + cut_score) / 2
|
||||
if cut_shorter_right:
|
||||
cut_score = fuzz.partial_ratio(shorter[:-cut_len], longer, processor=_no_processor)
|
||||
if cut_score > score:
|
||||
score = (score + cut_score) / 2
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def match_context(
|
||||
context: str,
|
||||
candidates: Generator[Tuple[str, str], None, None],
|
||||
threshold: float = 65.0,
|
||||
show_progress: bool = False,
|
||||
num_processes: int = None,
|
||||
chunksize: int = 1,
|
||||
min_length: int = 100,
|
||||
boost_split_overlaps: bool = True,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
Matches the context against multiple candidates. Candidates consist of a tuple of an id and its text.
|
||||
|
||||
Returns a sorted list of the candidate ids and its scores filtered by the threshold in descending order.
|
||||
|
||||
:param context: The context to match.
|
||||
:param candidates: The candidates to match the context.
|
||||
A candidate consists of a tuple of candidate id and candidate text.
|
||||
:param threshold: Score threshold that candidates must surpass to be included into the result list.
|
||||
:param show_progress: Whether to show the progress of matching all candidates.
|
||||
:param num_processes: The number of processes to be used for matching in parallel.
|
||||
:param chunksize: The chunksize used during parallel processing.
|
||||
If not specified chunksize is 1.
|
||||
For very long iterables using a large value for chunksize can make the job complete much faster than using the default value of 1.
|
||||
:param min_length: The minimum string length context and candidate need to have in order to be scored.
|
||||
Returns 0.0 otherwise.
|
||||
:param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
|
||||
If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
|
||||
we cut the context on the same side, recalculate the score and take the mean of both.
|
||||
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
|
||||
"""
|
||||
pool: Optional[multiprocessing.pool.Pool] = None
|
||||
try:
|
||||
score_candidate_args = ((context, candidate, min_length, boost_split_overlaps) for candidate in candidates)
|
||||
if num_processes is None or num_processes > 1:
|
||||
pool = Pool(processes=num_processes)
|
||||
candidate_scores: Iterable = pool.imap_unordered(
|
||||
_score_candidate, score_candidate_args, chunksize=chunksize
|
||||
)
|
||||
else:
|
||||
candidate_scores = map(_score_candidate, score_candidate_args)
|
||||
|
||||
if show_progress:
|
||||
candidate_scores = tqdm(candidate_scores)
|
||||
|
||||
matches = (candidate for candidate in candidate_scores if candidate.score > threshold)
|
||||
sorted_matches = sorted(matches, key=lambda candidate: candidate.score, reverse=True)
|
||||
match_list = list((candidate_score.candidate_id, candidate_score.score) for candidate_score in sorted_matches)
|
||||
|
||||
return match_list
|
||||
|
||||
finally:
|
||||
if pool:
|
||||
pool.close()
|
||||
|
||||
|
||||
def match_contexts(
|
||||
contexts: List[str],
|
||||
candidates: Generator[Tuple[str, str], None, None],
|
||||
threshold: float = 65.0,
|
||||
show_progress: bool = False,
|
||||
num_processes: int = None,
|
||||
chunksize: int = 1,
|
||||
min_length: int = 100,
|
||||
boost_split_overlaps: bool = True,
|
||||
) -> List[List[Tuple[str, float]]]:
|
||||
"""
|
||||
Matches the contexts against multiple candidates. Candidates consist of a tuple of an id and its string text.
|
||||
This method iterates over candidates only once.
|
||||
|
||||
Returns for each context a sorted list of the candidate ids and its scores filtered by the threshold in descending order.
|
||||
|
||||
:param contexts: The contexts to match.
|
||||
:param candidates: The candidates to match the context.
|
||||
A candidate consists of a tuple of candidate id and candidate text.
|
||||
:param threshold: Score threshold that candidates must surpass to be included into the result list.
|
||||
:param show_progress: Whether to show the progress of matching all candidates.
|
||||
:param num_processes: The number of processes to be used for matching in parallel.
|
||||
:param chunksize: The chunksize used during parallel processing.
|
||||
If not specified chunksize is 1.
|
||||
For very long iterables using a large value for chunksize can make the job complete much faster than using the default value of 1.
|
||||
:param min_length: The minimum string length context and candidate need to have in order to be scored.
|
||||
Returns 0.0 otherwise.
|
||||
:param boost_split_overlaps: Whether to boost split overlaps (e.g. [AB] <-> [BC]) that result from different preprocessing params.
|
||||
If we detect that the score is near a half match and the matching part of the candidate is at its boundaries
|
||||
we cut the context on the same side, recalculate the score and take the mean of both.
|
||||
Thus [AB] <-> [BC] (score ~50) gets recalculated with B <-> B (score ~100) scoring ~75 in total.
|
||||
"""
|
||||
pool: Optional[multiprocessing.pool.Pool] = None
|
||||
try:
|
||||
score_candidate_args = (
|
||||
(context, candidate, min_length, boost_split_overlaps)
|
||||
for candidate in candidates
|
||||
for context in enumerate(contexts)
|
||||
)
|
||||
|
||||
if num_processes is None or num_processes > 1:
|
||||
pool = Pool(processes=num_processes)
|
||||
candidate_scores: Iterable = pool.imap_unordered(
|
||||
_score_candidate, score_candidate_args, chunksize=chunksize
|
||||
)
|
||||
else:
|
||||
candidate_scores = map(_score_candidate, score_candidate_args)
|
||||
|
||||
if show_progress:
|
||||
candidate_scores = tqdm(candidate_scores)
|
||||
|
||||
match_lists: List[List[Tuple[str, float]]] = list()
|
||||
matches = (candidate for candidate in candidate_scores if candidate.score > threshold)
|
||||
group_sorted_matches = sorted(matches, key=lambda candidate: candidate.context_id)
|
||||
grouped_matches = groupby(group_sorted_matches, key=lambda candidate: candidate.context_id)
|
||||
for context_id, group in grouped_matches:
|
||||
sorted_group = sorted(group, key=lambda candidate: candidate.score, reverse=True)
|
||||
match_list = list((candiate_score.candidate_id, candiate_score.score) for candiate_score in sorted_group)
|
||||
match_lists.insert(context_id, match_list)
|
||||
|
||||
return match_lists
|
||||
|
||||
finally:
|
||||
if pool:
|
||||
pool.close()
|
||||
@ -96,6 +96,9 @@ install_requires =
|
||||
elasticsearch>=7.7,<=7.10
|
||||
elastic-apm
|
||||
|
||||
# context matching
|
||||
rapidfuzz
|
||||
|
||||
# Schema validation
|
||||
jsonschema
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
@ -6,9 +7,28 @@ from haystack.utils.preprocessing import convert_files_to_dicts, tika_convert_fi
|
||||
from haystack.utils.cleaning import clean_wiki_text
|
||||
from haystack.utils.augment_squad import augment_squad
|
||||
from haystack.utils.squad_data import SquadData
|
||||
from haystack.utils.context_matching import calculate_context_similarity, match_context, match_contexts
|
||||
|
||||
from .conftest import SAMPLES_PATH
|
||||
|
||||
TEST_CONTEXT = context = """Der Merkantilismus förderte Handel und Verkehr mit teils marktkonformen, teils dirigistischen Maßnahmen.
|
||||
An der Schwelle zum 19. Jahrhundert entstand ein neuer Typus des Nationalstaats, der die Säkularisation durchsetzte,
|
||||
moderne Bildungssysteme etablierte und die Industrialisierung vorantrieb.\n
|
||||
Beim Begriff der Aufklärung geht es auch um die Prozesse zwischen diesen frühneuzeitlichen Eckpunkten.
|
||||
Man versucht die fortschrittlichen Faktoren zu definieren, die in das 19. Jahrhundert führten.
|
||||
Widerstände gegen diesen Fortschritt werden anti-aufklärerischen Kräften oder unreflektierten Traditionen zugeordnet.
|
||||
Die Epochendefinition rückt vor allem publizistisch tätige Gruppen in den gesellschaftlichen Fokus,
|
||||
die zunächst selten einen bürgerlichen Hintergrund aufwiesen, sondern weitaus häufiger der Geistlichkeit oder Aristokratie angehörten:
|
||||
Wissenschaftler, Journalisten, Autoren, sogar Regenten, die Traditionen der Kritik unterzogen, indem sie sich auf die Vernunftperspektive beriefen."""
|
||||
|
||||
|
||||
TEST_CONTEXT_2 = """Beer is one of the oldest[1][2][3] and most widely consumed[4] alcoholic drinks in the world, and the third most popular drink overall after water and tea.[5] It is produced by the brewing and fermentation of starches, mainly derived from cereal grains—most commonly from malted barley, though wheat, maize (corn), rice, and oats are also used. During the brewing process, fermentation of the starch sugars in the wort produces ethanol and carbonation in the resulting beer.[6] Most modern beer is brewed with hops, which add bitterness and other flavours and act as a natural preservative and stabilizing agent. Other flavouring agents such as gruit, herbs, or fruits may be included or used instead of hops. In commercial brewing, the natural carbonation effect is often removed during processing and replaced with forced carbonation.[7]
|
||||
Some of humanity's earliest known writings refer to the production and distribution of beer: the Code of Hammurabi included laws regulating beer and beer parlours,[8] and "The Hymn to Ninkasi", a prayer to the Mesopotamian goddess of beer, served as both a prayer and as a method of remembering the recipe for beer in a culture with few literate people.[9][10]
|
||||
Beer is distributed in bottles and cans and is also commonly available on draught, particularly in pubs and bars. The brewing industry is a global business, consisting of several dominant multinational companies and many thousands of smaller producers ranging from brewpubs to regional breweries. The strength of modern beer is usually around 4% to 6% alcohol by volume (ABV), although it may vary between 0.5% and 20%, with some breweries creating examples of 40% ABV and above.[11]
|
||||
Beer forms part of the culture of many nations and is associated with social traditions such as beer festivals, as well as a rich pub culture involving activities like pub crawling, pub quizzes and pub games.
|
||||
When beer is distilled, the resulting liquor is a form of whisky.[12]
|
||||
"""
|
||||
|
||||
|
||||
def test_convert_files_to_dicts():
|
||||
documents = convert_files_to_dicts(
|
||||
@ -69,3 +89,170 @@ def test_squad_to_df():
|
||||
result = SquadData.df_to_data(df)
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_parts_of_whole_document():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size]
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_parts_of_whole_document_different_case():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size].lower()
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_parts_of_whole_document_different_whitesapce():
|
||||
whole_document = TEST_CONTEXT
|
||||
words = whole_document.split()
|
||||
min_length = 100
|
||||
context_word_size = 20
|
||||
for i in range(len(words) - context_word_size):
|
||||
partial_context = "\n\t\t\t".join(words[i : i + context_word_size])
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_min_length():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
context_size = min_length - 1
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size]
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
assert score == 0.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_partially_overlapping_contexts():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
stride = context_size // 2
|
||||
for i in range(len(whole_document) - context_size - stride):
|
||||
partial_context_1 = whole_document[i : i + context_size]
|
||||
partial_context_2 = whole_document[i + stride : i + stride + context_size]
|
||||
score = calculate_context_similarity(partial_context_1, partial_context_2, min_length=min_length)
|
||||
assert score >= 65.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_non_matching_contexts():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
scores = []
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size // 2] + _get_random_chars(context_size // 2)
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
scores.append(score)
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = (
|
||||
_get_random_chars(context_size // 2) + whole_document[i + context_size // 2 : i + context_size]
|
||||
)
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
scores.append(score)
|
||||
accuracy = np.where(np.array(scores) < 65, 1, 0).mean()
|
||||
assert accuracy > 0.99
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_parts_of_whole_document_with_noise():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = _insert_noise(whole_document[i : i + context_size], 0.1)
|
||||
score = calculate_context_similarity(partial_context, whole_document, min_length=min_length)
|
||||
assert score >= 85.0
|
||||
|
||||
|
||||
def test_calculate_context_similarity_on_partially_overlapping_contexts_with_noise():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
stride = context_size // 2
|
||||
scores = []
|
||||
for i in range(len(whole_document) - context_size - stride):
|
||||
partial_context_1 = whole_document[i : i + context_size]
|
||||
partial_context_2 = _insert_noise(whole_document[i + stride : i + stride + context_size], 0.1)
|
||||
score = calculate_context_similarity(partial_context_1, partial_context_2, min_length=min_length)
|
||||
scores.append(score)
|
||||
accuracy = np.where(np.array(scores) >= 65, 1, 0).mean()
|
||||
assert accuracy > 0.99
|
||||
|
||||
|
||||
def test_match_context():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size]
|
||||
candidates = ((str(i), TEST_CONTEXT if i == 0 else TEST_CONTEXT_2) for i in range(10))
|
||||
results = match_context(partial_context, candidates, min_length=min_length, num_processes=2)
|
||||
assert len(results) == 1
|
||||
id, score = results[0]
|
||||
assert id == "0"
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def test_match_context_single_process():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
for i in range(len(whole_document) - context_size):
|
||||
partial_context = whole_document[i : i + context_size]
|
||||
candidates = ((str(i), TEST_CONTEXT if i == 0 else TEST_CONTEXT_2) for i in range(10))
|
||||
results = match_context(partial_context, candidates, min_length=min_length, num_processes=1)
|
||||
assert len(results) == 1
|
||||
id, score = results[0]
|
||||
assert id == "0"
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def test_match_contexts():
|
||||
whole_document = TEST_CONTEXT
|
||||
min_length = 100
|
||||
margin = 5
|
||||
context_size = min_length + margin
|
||||
candidates = ((str(i), TEST_CONTEXT if i == 0 else TEST_CONTEXT_2) for i in range(10))
|
||||
partial_contexts = [whole_document[i : i + context_size] for i in range(len(whole_document) - context_size)]
|
||||
result_list = match_contexts(partial_contexts, candidates, min_length=min_length, num_processes=2)
|
||||
assert len(result_list) == len(partial_contexts)
|
||||
for results in result_list:
|
||||
assert len(results) == 1
|
||||
id, score = results[0]
|
||||
assert id == "0"
|
||||
assert score == 100.0
|
||||
|
||||
|
||||
def _get_random_chars(size: int):
|
||||
chars = np.random.choice(
|
||||
list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZß?/.,;:-#äöüÄÖÜ+*~1234567890$€%&!§ "), size=size
|
||||
)
|
||||
return "".join(list(chars))
|
||||
|
||||
|
||||
def _insert_noise(input: str, ratio):
|
||||
size = int(ratio * len(input))
|
||||
insert_idxs = sorted(np.random.choice(range(len(input)), size=size, replace=False), reverse=True)
|
||||
insert_chars = np.random.choice(
|
||||
list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZß?/.,;:-#äöüÄÖÜ+*~1234567890$€%&!§"), size=size
|
||||
)
|
||||
for idx, char in zip(insert_idxs, insert_chars):
|
||||
input = input[:idx] + char + input[idx:]
|
||||
return input
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user