mirror of
https://github.com/allenai/olmocr.git
synced 2026-01-08 13:22:25 +00:00
Adding a trailing repetition test
This commit is contained in:
parent
07466e1ae4
commit
9be696fa30
@ -20,7 +20,7 @@ import sys
|
||||
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
from .tests import BasePDFTest, load_tests
|
||||
from .tests import BasePDFTest, RepetitionTest, load_tests
|
||||
from .utils import calculate_bootstrap_ci, perform_permutation_test
|
||||
|
||||
def evaluate_candidate(
|
||||
@ -117,6 +117,12 @@ def main():
|
||||
default=os.path.join(os.path.dirname(__file__), "sample_data"),
|
||||
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--candidate",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Run test only for a single candidate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bootstrap_samples",
|
||||
type=int,
|
||||
@ -131,16 +137,14 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--permutation_tests",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="Number of permutations for statistical test (default: 10000).",
|
||||
action="store_true",
|
||||
help="Run permutation testing",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
input_folder = args.input_folder
|
||||
n_bootstrap = args.bootstrap_samples
|
||||
ci_level = args.confidence_level
|
||||
n_permutations = args.permutation_tests
|
||||
pdf_folder = os.path.join(input_folder, "pdfs")
|
||||
|
||||
# Check that the pdfs folder exists
|
||||
@ -173,17 +177,28 @@ def main():
|
||||
print("No valid tests found. Exiting.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Add in a default repeat test for every PDF that doesn't already have one
|
||||
for pdf in pdf_basenames:
|
||||
if not any(t.type == "repeat" for t in all_tests if t.pdf == pdf):
|
||||
all_tests.append(RepetitionTest(id=f"{pdf}_repeat", pdf=pdf, page=1, type="repeat"))
|
||||
|
||||
# Identify candidate pipeline folders (subdirectories of input_folder excluding /pdfs)
|
||||
candidate_folders = []
|
||||
for entry in os.listdir(input_folder):
|
||||
full_path = os.path.join(input_folder, entry)
|
||||
if os.path.isdir(full_path) and entry != "pdfs":
|
||||
candidate_folders.append(full_path)
|
||||
if args.candidate is not None:
|
||||
if entry == args.candidate:
|
||||
candidate_folders.append(full_path)
|
||||
else:
|
||||
if os.path.isdir(full_path) and entry != "pdfs":
|
||||
candidate_folders.append(full_path)
|
||||
|
||||
if not candidate_folders:
|
||||
print("Error: No candidate pipeline folders found (subdirectories besides 'pdfs').", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
candidate_folders.sort()
|
||||
|
||||
# Evaluate each candidate
|
||||
summary = []
|
||||
print("\nRunning tests for each candidate:")
|
||||
@ -238,62 +253,63 @@ def main():
|
||||
print("")
|
||||
|
||||
# Perform pairwise permutation tests
|
||||
print("\n" + "=" * 60)
|
||||
print("Pairwise Permutation Tests:")
|
||||
|
||||
valid_candidates = [c for c in summary if not c[3]] # Filter out candidates with errors
|
||||
olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" in c[0].lower()], key=lambda x: x[1], reverse=True)
|
||||
non_olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" not in c[0].lower()], key=lambda x: x[1], reverse=True)
|
||||
|
||||
top_olmocr = olmocr_candidates[0] if olmocr_candidates else None
|
||||
top_non_olmocr = non_olmocr_candidates[0] if non_olmocr_candidates else None
|
||||
top_two_olmocr = olmocr_candidates[:2]
|
||||
if args.permutation_tests:
|
||||
print("\n" + "=" * 60)
|
||||
print("Pairwise Permutation Tests:")
|
||||
|
||||
valid_candidates = [c for c in summary if not c[3]] # Filter out candidates with errors
|
||||
olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" in c[0].lower()], key=lambda x: x[1], reverse=True)
|
||||
non_olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" not in c[0].lower()], key=lambda x: x[1], reverse=True)
|
||||
|
||||
top_olmocr = olmocr_candidates[0] if olmocr_candidates else None
|
||||
top_non_olmocr = non_olmocr_candidates[0] if non_olmocr_candidates else None
|
||||
top_two_olmocr = olmocr_candidates[:2]
|
||||
|
||||
# Test 1: Top olmocr vs Top non-olmocr
|
||||
if top_olmocr and top_non_olmocr:
|
||||
olmocr_name, olmocr_score = top_olmocr[0], top_olmocr[1]
|
||||
non_olmocr_name, non_olmocr_score = top_non_olmocr[0], top_non_olmocr[1]
|
||||
olmocr_scores = top_olmocr[7] # all_test_scores
|
||||
non_olmocr_scores = top_non_olmocr[7] # all_test_scores
|
||||
|
||||
diff, p_value = perform_permutation_test(
|
||||
olmocr_scores, non_olmocr_scores, n_permutations=n_permutations
|
||||
)
|
||||
|
||||
print(f"\nComparison 1: Top olmocr vs Top non-olmocr candidate")
|
||||
print(f" {olmocr_name} ({olmocr_score*100:.1f}%) vs {non_olmocr_name} ({non_olmocr_score*100:.1f}%)")
|
||||
print(f" Difference: {diff*100:.2f}% (positive means {olmocr_name} is better)")
|
||||
print(f" p-value: {p_value:.4f}")
|
||||
if p_value < 0.05:
|
||||
print(f" Result: Statistically significant difference (p < 0.05)")
|
||||
# Test 1: Top olmocr vs Top non-olmocr
|
||||
if top_olmocr and top_non_olmocr:
|
||||
olmocr_name, olmocr_score = top_olmocr[0], top_olmocr[1]
|
||||
non_olmocr_name, non_olmocr_score = top_non_olmocr[0], top_non_olmocr[1]
|
||||
olmocr_scores = top_olmocr[7] # all_test_scores
|
||||
non_olmocr_scores = top_non_olmocr[7] # all_test_scores
|
||||
|
||||
diff, p_value = perform_permutation_test(
|
||||
olmocr_scores, non_olmocr_scores
|
||||
)
|
||||
|
||||
print(f"\nComparison 1: Top olmocr vs Top non-olmocr candidate")
|
||||
print(f" {olmocr_name} ({olmocr_score*100:.1f}%) vs {non_olmocr_name} ({non_olmocr_score*100:.1f}%)")
|
||||
print(f" Difference: {diff*100:.2f}% (positive means {olmocr_name} is better)")
|
||||
print(f" p-value: {p_value:.4f}")
|
||||
if p_value < 0.05:
|
||||
print(f" Result: Statistically significant difference (p < 0.05)")
|
||||
else:
|
||||
print(f" Result: No statistically significant difference (p ≥ 0.05)")
|
||||
else:
|
||||
print(f" Result: No statistically significant difference (p ≥ 0.05)")
|
||||
else:
|
||||
print("\nCannot perform olmocr vs non-olmocr comparison: Missing candidates")
|
||||
|
||||
# Test 2: Top two olmocr candidates (if there are at least two)
|
||||
if len(top_two_olmocr) >= 2:
|
||||
olmocr1_name, olmocr1_score = top_two_olmocr[0][0], top_two_olmocr[0][1]
|
||||
olmocr2_name, olmocr2_score = top_two_olmocr[1][0], top_two_olmocr[1][1]
|
||||
olmocr1_scores = top_two_olmocr[0][7] # all_test_scores
|
||||
olmocr2_scores = top_two_olmocr[1][7] # all_test_scores
|
||||
print("\nCannot perform olmocr vs non-olmocr comparison: Missing candidates")
|
||||
|
||||
diff, p_value = perform_permutation_test(
|
||||
olmocr1_scores, olmocr2_scores, n_permutations=n_permutations
|
||||
)
|
||||
|
||||
print(f"\nComparison 2: Top two olmocr candidates")
|
||||
print(f" {olmocr1_name} ({olmocr1_score*100:.1f}%) vs {olmocr2_name} ({olmocr2_score*100:.1f}%)")
|
||||
print(f" Difference: {diff*100:.2f}% (positive means {olmocr1_name} is better)")
|
||||
print(f" p-value: {p_value:.4f}")
|
||||
if p_value < 0.05:
|
||||
print(f" Result: Statistically significant difference (p < 0.05)")
|
||||
# Test 2: Top two olmocr candidates (if there are at least two)
|
||||
if len(top_two_olmocr) >= 2:
|
||||
olmocr1_name, olmocr1_score = top_two_olmocr[0][0], top_two_olmocr[0][1]
|
||||
olmocr2_name, olmocr2_score = top_two_olmocr[1][0], top_two_olmocr[1][1]
|
||||
olmocr1_scores = top_two_olmocr[0][7] # all_test_scores
|
||||
olmocr2_scores = top_two_olmocr[1][7] # all_test_scores
|
||||
|
||||
diff, p_value = perform_permutation_test(
|
||||
olmocr1_scores, olmocr2_scores
|
||||
)
|
||||
|
||||
print(f"\nComparison 2: Top two olmocr candidates")
|
||||
print(f" {olmocr1_name} ({olmocr1_score*100:.1f}%) vs {olmocr2_name} ({olmocr2_score*100:.1f}%)")
|
||||
print(f" Difference: {diff*100:.2f}% (positive means {olmocr1_name} is better)")
|
||||
print(f" p-value: {p_value:.4f}")
|
||||
if p_value < 0.05:
|
||||
print(f" Result: Statistically significant difference (p < 0.05)")
|
||||
else:
|
||||
print(f" Result: No statistically significant difference (p ≥ 0.05)")
|
||||
else:
|
||||
print(f" Result: No statistically significant difference (p ≥ 0.05)")
|
||||
else:
|
||||
print("\nCannot perform top two olmocr comparison: Not enough olmocr candidates")
|
||||
|
||||
print("=" * 60)
|
||||
print("\nCannot perform top two olmocr comparison: Not enough olmocr candidates")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -22,6 +22,9 @@
|
||||
|
||||
{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_02", "type": "present", "checked": "verified", "text": "Use the graph of the position function to determine the time intervals when the velocity is positive, negative, or zero."}
|
||||
{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_03", "type": "present", "checked": "verified", "text": "Use the graph of the velocity function to determine the time intervals when the acceleration is positive, negative, or zero."}
|
||||
{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_04", "type": "order", "before": "150.", "after": "157."}
|
||||
{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_05", "type": "order", "before": "150.", "after": "158."}
|
||||
{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_06", "type": "order", "before": "150.", "after": "159."}
|
||||
|
||||
{"pdf": "multi_column_miss.pdf", "page": 1, "id": "multi_column_miss_minediff_01", "type": "present", "checked": "verified", "text": "This report first provides the context and development of CSR; then, from internal company documents, examines how PM came to its own version."}
|
||||
{"pdf": "multi_column_miss.pdf", "page": 1, "id": "multi_column_miss_minediff_02", "type": "present", "checked": "verified", "text": "This paper examines whether a tobacco company espousing CSR should be judged simply as a corporate entity along standards of business ethics, or as an irretrievably negative force in the realm of public health, thereby rendering CSR an oxymoron."}
|
||||
@ -39,20 +42,18 @@
|
||||
{"pdf": "olmo2-pg4.pdf", "page": 1, "id": "olmo2-pg4_table08", "type": "table", "cell": "Math proofs code", "left_heading": "Algebraic Stack"}
|
||||
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Quadratic regression", "left": "Challenge"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Instrument Use", "left": "Normal"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.87", "top_heading": "Procedure"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.87", "top_heading": "ReACT"}
|
||||
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Pick-and-place object", "left_heading": "27"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.66", "right": "0.44"}
|
||||
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Interact with a moving agent", "top_heading": "Unit Test Topic"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t01", "type": "table", "cell": "Instrument Use", "left": "Normal"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t02", "type": "table", "cell": "0.87", "top_heading": "Procedure"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t03", "type": "table", "cell": "0.87", "top_heading": "ReACT"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t04", "type": "table", "cell": "Pick-and-place object", "left_heading": "27"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t05", "type": "table", "cell": "0.66", "right": "0.44"}
|
||||
{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t06", "type": "table", "cell": "Interact with a moving agent", "top_heading": "Unit Test Topic"}
|
||||
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table00", "type": "table", "cell": "1,136", "top_heading": "Year Ended"}
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table01", "type": "table", "cell": "Year Ended"}
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table02", "type": "table", "cell": "680", "up": "1,892"}
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table02", "type": "table", "cell": "2,532", "left_heading": "Research and development"}
|
||||
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table03", "type": "table", "cell": "2,532", "left_heading": "Research and development"}
|
||||
{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table04", "type": "absent", "text": "62"}
|
||||
|
||||
|
||||
|
||||
|
||||
@ -10,96 +10,7 @@ from typing import List, Optional, Tuple, Dict, Any
|
||||
from fuzzysearch import find_near_matches
|
||||
from rapidfuzz import fuzz
|
||||
|
||||
|
||||
def parse_markdown_tables(md_content: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Extract and parse all markdown tables from the provided content.
|
||||
|
||||
Args:
|
||||
md_content: The markdown content containing tables
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a parsed table
|
||||
"""
|
||||
# Extract all tables from markdown
|
||||
table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[:-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)'
|
||||
table_matches = re.finditer(table_pattern, md_content)
|
||||
|
||||
parsed_tables = []
|
||||
|
||||
for table_match in table_matches:
|
||||
# Extract header and body from the table match
|
||||
header_row = table_match.group(1).strip()
|
||||
body_rows = table_match.group(2).strip().split('\n')
|
||||
|
||||
# Process header and rows to remove leading/trailing |
|
||||
header_cells = [cell.strip() for cell in header_row.split('|')]
|
||||
if header_cells[0] == '':
|
||||
header_cells = header_cells[1:]
|
||||
if header_cells[-1] == '':
|
||||
header_cells = header_cells[:-1]
|
||||
|
||||
# Process table body rows
|
||||
table_data = []
|
||||
for row in [header_row] + body_rows:
|
||||
if '|' not in row: # Skip separator row
|
||||
continue
|
||||
|
||||
cells = [cell.strip() for cell in row.split('|')]
|
||||
if cells[0] == '':
|
||||
cells = cells[1:]
|
||||
if cells[-1] == '':
|
||||
cells = cells[:-1]
|
||||
|
||||
table_data.append(cells)
|
||||
|
||||
# Skip separator row (second row with dashes)
|
||||
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
|
||||
table_data = [table_data[0]] + table_data[2:]
|
||||
|
||||
# Convert to numpy array for easier manipulation
|
||||
# First ensure all rows have the same number of columns by padding if necessary
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
|
||||
parsed_tables.append(table_array)
|
||||
|
||||
return parsed_tables
|
||||
|
||||
|
||||
def parse_html_tables(html_content: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Extract and parse all HTML tables from the provided content.
|
||||
|
||||
Args:
|
||||
html_content: The HTML content containing tables
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a parsed table
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
tables = soup.find_all('table')
|
||||
|
||||
parsed_tables = []
|
||||
|
||||
for table in tables:
|
||||
rows = table.find_all(['tr'])
|
||||
table_data = []
|
||||
|
||||
for row in rows:
|
||||
cells = row.find_all(['th', 'td'])
|
||||
row_data = [cell.get_text().strip() for cell in cells]
|
||||
table_data.append(row_data)
|
||||
|
||||
# Ensure all rows have the same number of columns
|
||||
if table_data:
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
parsed_tables.append(table_array)
|
||||
|
||||
return parsed_tables
|
||||
from olmocr.repeatdetect import RepeatDetector
|
||||
|
||||
|
||||
class TestType(str, Enum):
|
||||
@ -107,6 +18,7 @@ class TestType(str, Enum):
|
||||
ABSENT = "absent"
|
||||
ORDER = "order"
|
||||
TABLE = "table"
|
||||
REPEAT = "repeat"
|
||||
|
||||
|
||||
class TestChecked(str, Enum):
|
||||
@ -239,8 +151,95 @@ class TextOrderTest(BasePDFTest):
|
||||
return True, ""
|
||||
return False, (f"Could not find a location where '{self.before[:40]}...' appears before " f"'{self.after[:40]}...'.")
|
||||
|
||||
def parse_markdown_tables(md_content: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Extract and parse all markdown tables from the provided content.
|
||||
|
||||
Args:
|
||||
md_content: The markdown content containing tables
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a parsed table
|
||||
"""
|
||||
# Extract all tables from markdown
|
||||
table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[:-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)'
|
||||
table_matches = re.finditer(table_pattern, md_content)
|
||||
|
||||
parsed_tables = []
|
||||
|
||||
for table_match in table_matches:
|
||||
# Extract header and body from the table match
|
||||
header_row = table_match.group(1).strip()
|
||||
body_rows = table_match.group(2).strip().split('\n')
|
||||
|
||||
# Process header and rows to remove leading/trailing |
|
||||
header_cells = [cell.strip() for cell in header_row.split('|')]
|
||||
if header_cells[0] == '':
|
||||
header_cells = header_cells[1:]
|
||||
if header_cells[-1] == '':
|
||||
header_cells = header_cells[:-1]
|
||||
|
||||
# Process table body rows
|
||||
table_data = []
|
||||
for row in [header_row] + body_rows:
|
||||
if '|' not in row: # Skip separator row
|
||||
continue
|
||||
|
||||
cells = [cell.strip() for cell in row.split('|')]
|
||||
if cells[0] == '':
|
||||
cells = cells[1:]
|
||||
if cells[-1] == '':
|
||||
cells = cells[:-1]
|
||||
|
||||
table_data.append(cells)
|
||||
|
||||
# Skip separator row (second row with dashes)
|
||||
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
|
||||
table_data = [table_data[0]] + table_data[2:]
|
||||
|
||||
# Convert to numpy array for easier manipulation
|
||||
# First ensure all rows have the same number of columns by padding if necessary
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
|
||||
parsed_tables.append(table_array)
|
||||
|
||||
return parsed_tables
|
||||
|
||||
|
||||
def parse_html_tables(html_content: str) -> List[np.ndarray]:
|
||||
"""
|
||||
Extract and parse all HTML tables from the provided content.
|
||||
|
||||
Args:
|
||||
html_content: The HTML content containing tables
|
||||
|
||||
Returns:
|
||||
A list of numpy arrays, each representing a parsed table
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
tables = soup.find_all('table')
|
||||
|
||||
parsed_tables = []
|
||||
|
||||
for table in tables:
|
||||
rows = table.find_all(['tr'])
|
||||
table_data = []
|
||||
|
||||
for row in rows:
|
||||
cells = row.find_all(['th', 'td'])
|
||||
row_data = [cell.get_text().strip() for cell in cells]
|
||||
table_data.append(row_data)
|
||||
|
||||
# Ensure all rows have the same number of columns
|
||||
if table_data:
|
||||
max_cols = max(len(row) for row in table_data)
|
||||
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
|
||||
table_array = np.array(padded_data)
|
||||
parsed_tables.append(table_array)
|
||||
|
||||
return parsed_tables
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -401,6 +400,23 @@ class TableTest(BasePDFTest):
|
||||
return False, f"Found cells matching '{self.cell}' but relationships were not satisfied: {'; '.join(failed_reasons)}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepetitionTest(BasePDFTest):
|
||||
max_repeats: int=10
|
||||
|
||||
def run(self, content: str) -> Tuple[bool, str]:
|
||||
# Makes sure that the content has no egregious repeated ngrams at the end, which indicate a degradation of quality
|
||||
d = RepeatDetector(max_ngram_size=5)
|
||||
d.add_letters(content)
|
||||
repeats = d.ngram_repeats()
|
||||
|
||||
for index, count in enumerate(repeats):
|
||||
if count > self.max_repeats:
|
||||
return False, f"Text ends with {count} repeating {index+1}-grams, invalid"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
|
||||
"""
|
||||
Load tests from a JSONL file.
|
||||
@ -412,6 +428,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
|
||||
A list of test objects.
|
||||
"""
|
||||
tests: List[BasePDFTest] = []
|
||||
unique_ids = set()
|
||||
with open(jsonl_file, "r") as file:
|
||||
for line_number, line in enumerate(file, start=1):
|
||||
line = line.strip()
|
||||
@ -430,6 +447,11 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]:
|
||||
else:
|
||||
raise ValidationError(f"Unknown test type: {test_type}")
|
||||
|
||||
if test.id in unique_ids:
|
||||
raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
|
||||
else:
|
||||
unique_ids.add(test.id)
|
||||
|
||||
tests.append(test)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON on line {line_number}: {e}")
|
||||
|
||||
@ -2,7 +2,7 @@ import random
|
||||
import string
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import re
|
||||
|
||||
class RepeatDetector:
|
||||
def __init__(self, max_ngram_size: int = 10):
|
||||
@ -18,20 +18,23 @@ class RepeatDetector:
|
||||
if not self.data:
|
||||
return result
|
||||
|
||||
# Normalize all whitespace to single spaces
|
||||
text = re.sub(r'\s+', ' ', self.data)
|
||||
|
||||
# For each n-gram size
|
||||
for size in range(1, self.max_ngram_size + 1):
|
||||
if len(self.data) < size:
|
||||
if len(text) < size:
|
||||
continue
|
||||
|
||||
# Get the last n-gram
|
||||
target = self.data[-size:]
|
||||
target = text[-size:]
|
||||
|
||||
# Count backwards from the end to find repeats
|
||||
count = 0
|
||||
pos = len(self.data) - size # Start position for previous n-gram
|
||||
pos = len(text) - size # Start position for previous n-gram
|
||||
|
||||
while pos >= 0:
|
||||
if self.data[pos : pos + size] == target:
|
||||
if text[pos : pos + size] == target:
|
||||
count += 1
|
||||
pos -= size # Move back by the size of the n-gram
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user