From 9be696fa306ddaefbfd9ba87f43bfa4d9bdb81d3 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 6 Mar 2025 08:56:16 -0800 Subject: [PATCH] Adding a trailing repetition test --- olmocr/bench/benchmark.py | 134 ++++++++-------- olmocr/bench/sample_data/dataset.jsonl | 21 +-- olmocr/bench/tests.py | 202 ++++++++++++++----------- olmocr/repeatdetect.py | 13 +- 4 files changed, 206 insertions(+), 164 deletions(-) diff --git a/olmocr/bench/benchmark.py b/olmocr/bench/benchmark.py index ae7cf8a..51fd08d 100644 --- a/olmocr/bench/benchmark.py +++ b/olmocr/bench/benchmark.py @@ -20,7 +20,7 @@ import sys from typing import Dict, List, Tuple, Optional -from .tests import BasePDFTest, load_tests +from .tests import BasePDFTest, RepetitionTest, load_tests from .utils import calculate_bootstrap_ci, perform_permutation_test def evaluate_candidate( @@ -117,6 +117,12 @@ def main(): default=os.path.join(os.path.dirname(__file__), "sample_data"), help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.", ) + parser.add_argument( + "--candidate", + type=str, + default=None, + help="Run test only for a single candidate" + ) parser.add_argument( "--bootstrap_samples", type=int, @@ -131,16 +137,14 @@ def main(): ) parser.add_argument( "--permutation_tests", - type=int, - default=10000, - help="Number of permutations for statistical test (default: 10000).", + action="store_true", + help="Run permutation testing", ) args = parser.parse_args() input_folder = args.input_folder n_bootstrap = args.bootstrap_samples ci_level = args.confidence_level - n_permutations = args.permutation_tests pdf_folder = os.path.join(input_folder, "pdfs") # Check that the pdfs folder exists @@ -173,17 +177,28 @@ def main(): print("No valid tests found. Exiting.", file=sys.stderr) sys.exit(1) + # Add in a default repeat test for every PDF that doesn't already have one + for pdf in pdf_basenames: + if not any(t.type == "repeat" for t in all_tests if t.pdf == pdf): + all_tests.append(RepetitionTest(id=f"{pdf}_repeat", pdf=pdf, page=1, type="repeat")) + # Identify candidate pipeline folders (subdirectories of input_folder excluding /pdfs) candidate_folders = [] for entry in os.listdir(input_folder): full_path = os.path.join(input_folder, entry) - if os.path.isdir(full_path) and entry != "pdfs": - candidate_folders.append(full_path) + if args.candidate is not None: + if entry == args.candidate: + candidate_folders.append(full_path) + else: + if os.path.isdir(full_path) and entry != "pdfs": + candidate_folders.append(full_path) if not candidate_folders: print("Error: No candidate pipeline folders found (subdirectories besides 'pdfs').", file=sys.stderr) sys.exit(1) + candidate_folders.sort() + # Evaluate each candidate summary = [] print("\nRunning tests for each candidate:") @@ -238,62 +253,63 @@ def main(): print("") # Perform pairwise permutation tests - print("\n" + "=" * 60) - print("Pairwise Permutation Tests:") - - valid_candidates = [c for c in summary if not c[3]] # Filter out candidates with errors - olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" in c[0].lower()], key=lambda x: x[1], reverse=True) - non_olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" not in c[0].lower()], key=lambda x: x[1], reverse=True) - - top_olmocr = olmocr_candidates[0] if olmocr_candidates else None - top_non_olmocr = non_olmocr_candidates[0] if non_olmocr_candidates else None - top_two_olmocr = olmocr_candidates[:2] + if args.permutation_tests: + print("\n" + "=" * 60) + print("Pairwise Permutation Tests:") + + valid_candidates = [c for c in summary if not c[3]] # Filter out candidates with errors + olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" in c[0].lower()], key=lambda x: x[1], reverse=True) + non_olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" not in c[0].lower()], key=lambda x: x[1], reverse=True) + + top_olmocr = olmocr_candidates[0] if olmocr_candidates else None + top_non_olmocr = non_olmocr_candidates[0] if non_olmocr_candidates else None + top_two_olmocr = olmocr_candidates[:2] - # Test 1: Top olmocr vs Top non-olmocr - if top_olmocr and top_non_olmocr: - olmocr_name, olmocr_score = top_olmocr[0], top_olmocr[1] - non_olmocr_name, non_olmocr_score = top_non_olmocr[0], top_non_olmocr[1] - olmocr_scores = top_olmocr[7] # all_test_scores - non_olmocr_scores = top_non_olmocr[7] # all_test_scores - - diff, p_value = perform_permutation_test( - olmocr_scores, non_olmocr_scores, n_permutations=n_permutations - ) - - print(f"\nComparison 1: Top olmocr vs Top non-olmocr candidate") - print(f" {olmocr_name} ({olmocr_score*100:.1f}%) vs {non_olmocr_name} ({non_olmocr_score*100:.1f}%)") - print(f" Difference: {diff*100:.2f}% (positive means {olmocr_name} is better)") - print(f" p-value: {p_value:.4f}") - if p_value < 0.05: - print(f" Result: Statistically significant difference (p < 0.05)") + # Test 1: Top olmocr vs Top non-olmocr + if top_olmocr and top_non_olmocr: + olmocr_name, olmocr_score = top_olmocr[0], top_olmocr[1] + non_olmocr_name, non_olmocr_score = top_non_olmocr[0], top_non_olmocr[1] + olmocr_scores = top_olmocr[7] # all_test_scores + non_olmocr_scores = top_non_olmocr[7] # all_test_scores + + diff, p_value = perform_permutation_test( + olmocr_scores, non_olmocr_scores + ) + + print(f"\nComparison 1: Top olmocr vs Top non-olmocr candidate") + print(f" {olmocr_name} ({olmocr_score*100:.1f}%) vs {non_olmocr_name} ({non_olmocr_score*100:.1f}%)") + print(f" Difference: {diff*100:.2f}% (positive means {olmocr_name} is better)") + print(f" p-value: {p_value:.4f}") + if p_value < 0.05: + print(f" Result: Statistically significant difference (p < 0.05)") + else: + print(f" Result: No statistically significant difference (p ≥ 0.05)") else: - print(f" Result: No statistically significant difference (p ≥ 0.05)") - else: - print("\nCannot perform olmocr vs non-olmocr comparison: Missing candidates") - - # Test 2: Top two olmocr candidates (if there are at least two) - if len(top_two_olmocr) >= 2: - olmocr1_name, olmocr1_score = top_two_olmocr[0][0], top_two_olmocr[0][1] - olmocr2_name, olmocr2_score = top_two_olmocr[1][0], top_two_olmocr[1][1] - olmocr1_scores = top_two_olmocr[0][7] # all_test_scores - olmocr2_scores = top_two_olmocr[1][7] # all_test_scores + print("\nCannot perform olmocr vs non-olmocr comparison: Missing candidates") - diff, p_value = perform_permutation_test( - olmocr1_scores, olmocr2_scores, n_permutations=n_permutations - ) - - print(f"\nComparison 2: Top two olmocr candidates") - print(f" {olmocr1_name} ({olmocr1_score*100:.1f}%) vs {olmocr2_name} ({olmocr2_score*100:.1f}%)") - print(f" Difference: {diff*100:.2f}% (positive means {olmocr1_name} is better)") - print(f" p-value: {p_value:.4f}") - if p_value < 0.05: - print(f" Result: Statistically significant difference (p < 0.05)") + # Test 2: Top two olmocr candidates (if there are at least two) + if len(top_two_olmocr) >= 2: + olmocr1_name, olmocr1_score = top_two_olmocr[0][0], top_two_olmocr[0][1] + olmocr2_name, olmocr2_score = top_two_olmocr[1][0], top_two_olmocr[1][1] + olmocr1_scores = top_two_olmocr[0][7] # all_test_scores + olmocr2_scores = top_two_olmocr[1][7] # all_test_scores + + diff, p_value = perform_permutation_test( + olmocr1_scores, olmocr2_scores + ) + + print(f"\nComparison 2: Top two olmocr candidates") + print(f" {olmocr1_name} ({olmocr1_score*100:.1f}%) vs {olmocr2_name} ({olmocr2_score*100:.1f}%)") + print(f" Difference: {diff*100:.2f}% (positive means {olmocr1_name} is better)") + print(f" p-value: {p_value:.4f}") + if p_value < 0.05: + print(f" Result: Statistically significant difference (p < 0.05)") + else: + print(f" Result: No statistically significant difference (p ≥ 0.05)") else: - print(f" Result: No statistically significant difference (p ≥ 0.05)") - else: - print("\nCannot perform top two olmocr comparison: Not enough olmocr candidates") - - print("=" * 60) + print("\nCannot perform top two olmocr comparison: Not enough olmocr candidates") + + print("=" * 60) if __name__ == "__main__": diff --git a/olmocr/bench/sample_data/dataset.jsonl b/olmocr/bench/sample_data/dataset.jsonl index a675e56..a5ec6f1 100644 --- a/olmocr/bench/sample_data/dataset.jsonl +++ b/olmocr/bench/sample_data/dataset.jsonl @@ -22,6 +22,9 @@ {"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_02", "type": "present", "checked": "verified", "text": "Use the graph of the position function to determine the time intervals when the velocity is positive, negative, or zero."} {"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_03", "type": "present", "checked": "verified", "text": "Use the graph of the velocity function to determine the time intervals when the acceleration is positive, negative, or zero."} +{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_04", "type": "order", "before": "150.", "after": "157."} +{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_05", "type": "order", "before": "150.", "after": "158."} +{"pdf": "openstax_caculus_pg_273.pdf", "page": 1, "id": "openstax_caculus_pg_273_minediff_06", "type": "order", "before": "150.", "after": "159."} {"pdf": "multi_column_miss.pdf", "page": 1, "id": "multi_column_miss_minediff_01", "type": "present", "checked": "verified", "text": "This report first provides the context and development of CSR; then, from internal company documents, examines how PM came to its own version."} {"pdf": "multi_column_miss.pdf", "page": 1, "id": "multi_column_miss_minediff_02", "type": "present", "checked": "verified", "text": "This paper examines whether a tobacco company espousing CSR should be judged simply as a corporate entity along standards of business ethics, or as an irretrievably negative force in the realm of public health, thereby rendering CSR an oxymoron."} @@ -39,20 +42,18 @@ {"pdf": "olmo2-pg4.pdf", "page": 1, "id": "olmo2-pg4_table08", "type": "table", "cell": "Math proofs code", "left_heading": "Algebraic Stack"} {"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Quadratic regression", "left": "Challenge"} -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Instrument Use", "left": "Normal"} -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.87", "top_heading": "Procedure"} -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.87", "top_heading": "ReACT"} - -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Pick-and-place object", "left_heading": "27"} -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "0.66", "right": "0.44"} - -{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t00", "type": "table", "cell": "Interact with a moving agent", "top_heading": "Unit Test Topic"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t01", "type": "table", "cell": "Instrument Use", "left": "Normal"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t02", "type": "table", "cell": "0.87", "top_heading": "Procedure"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t03", "type": "table", "cell": "0.87", "top_heading": "ReACT"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t04", "type": "table", "cell": "Pick-and-place object", "left_heading": "27"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t05", "type": "table", "cell": "0.66", "right": "0.44"} +{"pdf": "discoverworld_crazy_table4.pdf", "page": 1, "id": "olmo2-discoverworld_crazy_table4_t06", "type": "table", "cell": "Interact with a moving agent", "top_heading": "Unit Test Topic"} {"pdf": "earnings.pdf", "page": 1, "id": "earnings_table00", "type": "table", "cell": "1,136", "top_heading": "Year Ended"} {"pdf": "earnings.pdf", "page": 1, "id": "earnings_table01", "type": "table", "cell": "Year Ended"} {"pdf": "earnings.pdf", "page": 1, "id": "earnings_table02", "type": "table", "cell": "680", "up": "1,892"} -{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table02", "type": "table", "cell": "2,532", "left_heading": "Research and development"} - +{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table03", "type": "table", "cell": "2,532", "left_heading": "Research and development"} +{"pdf": "earnings.pdf", "page": 1, "id": "earnings_table04", "type": "absent", "text": "62"} diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py index 3464dad..83a361a 100644 --- a/olmocr/bench/tests.py +++ b/olmocr/bench/tests.py @@ -10,96 +10,7 @@ from typing import List, Optional, Tuple, Dict, Any from fuzzysearch import find_near_matches from rapidfuzz import fuzz - -def parse_markdown_tables(md_content: str) -> List[np.ndarray]: - """ - Extract and parse all markdown tables from the provided content. - - Args: - md_content: The markdown content containing tables - - Returns: - A list of numpy arrays, each representing a parsed table - """ - # Extract all tables from markdown - table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[:-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)' - table_matches = re.finditer(table_pattern, md_content) - - parsed_tables = [] - - for table_match in table_matches: - # Extract header and body from the table match - header_row = table_match.group(1).strip() - body_rows = table_match.group(2).strip().split('\n') - - # Process header and rows to remove leading/trailing | - header_cells = [cell.strip() for cell in header_row.split('|')] - if header_cells[0] == '': - header_cells = header_cells[1:] - if header_cells[-1] == '': - header_cells = header_cells[:-1] - - # Process table body rows - table_data = [] - for row in [header_row] + body_rows: - if '|' not in row: # Skip separator row - continue - - cells = [cell.strip() for cell in row.split('|')] - if cells[0] == '': - cells = cells[1:] - if cells[-1] == '': - cells = cells[:-1] - - table_data.append(cells) - - # Skip separator row (second row with dashes) - if len(table_data) > 1 and all('-' in cell for cell in table_data[1]): - table_data = [table_data[0]] + table_data[2:] - - # Convert to numpy array for easier manipulation - # First ensure all rows have the same number of columns by padding if necessary - max_cols = max(len(row) for row in table_data) - padded_data = [row + [''] * (max_cols - len(row)) for row in table_data] - table_array = np.array(padded_data) - - parsed_tables.append(table_array) - - return parsed_tables - - -def parse_html_tables(html_content: str) -> List[np.ndarray]: - """ - Extract and parse all HTML tables from the provided content. - - Args: - html_content: The HTML content containing tables - - Returns: - A list of numpy arrays, each representing a parsed table - """ - soup = BeautifulSoup(html_content, 'html.parser') - tables = soup.find_all('table') - - parsed_tables = [] - - for table in tables: - rows = table.find_all(['tr']) - table_data = [] - - for row in rows: - cells = row.find_all(['th', 'td']) - row_data = [cell.get_text().strip() for cell in cells] - table_data.append(row_data) - - # Ensure all rows have the same number of columns - if table_data: - max_cols = max(len(row) for row in table_data) - padded_data = [row + [''] * (max_cols - len(row)) for row in table_data] - table_array = np.array(padded_data) - parsed_tables.append(table_array) - - return parsed_tables +from olmocr.repeatdetect import RepeatDetector class TestType(str, Enum): @@ -107,6 +18,7 @@ class TestType(str, Enum): ABSENT = "absent" ORDER = "order" TABLE = "table" + REPEAT = "repeat" class TestChecked(str, Enum): @@ -239,8 +151,95 @@ class TextOrderTest(BasePDFTest): return True, "" return False, (f"Could not find a location where '{self.before[:40]}...' appears before " f"'{self.after[:40]}...'.") +def parse_markdown_tables(md_content: str) -> List[np.ndarray]: + """ + Extract and parse all markdown tables from the provided content. + + Args: + md_content: The markdown content containing tables + + Returns: + A list of numpy arrays, each representing a parsed table + """ + # Extract all tables from markdown + table_pattern = r'(\|(?:[^|]*\|)+)\s*\n\|(?:[:-]+\|)+\s*\n((?:\|(?:[^|]*\|)+\s*\n)+)' + table_matches = re.finditer(table_pattern, md_content) + + parsed_tables = [] + + for table_match in table_matches: + # Extract header and body from the table match + header_row = table_match.group(1).strip() + body_rows = table_match.group(2).strip().split('\n') + + # Process header and rows to remove leading/trailing | + header_cells = [cell.strip() for cell in header_row.split('|')] + if header_cells[0] == '': + header_cells = header_cells[1:] + if header_cells[-1] == '': + header_cells = header_cells[:-1] + + # Process table body rows + table_data = [] + for row in [header_row] + body_rows: + if '|' not in row: # Skip separator row + continue + + cells = [cell.strip() for cell in row.split('|')] + if cells[0] == '': + cells = cells[1:] + if cells[-1] == '': + cells = cells[:-1] + + table_data.append(cells) + + # Skip separator row (second row with dashes) + if len(table_data) > 1 and all('-' in cell for cell in table_data[1]): + table_data = [table_data[0]] + table_data[2:] + + # Convert to numpy array for easier manipulation + # First ensure all rows have the same number of columns by padding if necessary + max_cols = max(len(row) for row in table_data) + padded_data = [row + [''] * (max_cols - len(row)) for row in table_data] + table_array = np.array(padded_data) + + parsed_tables.append(table_array) + + return parsed_tables +def parse_html_tables(html_content: str) -> List[np.ndarray]: + """ + Extract and parse all HTML tables from the provided content. + + Args: + html_content: The HTML content containing tables + + Returns: + A list of numpy arrays, each representing a parsed table + """ + soup = BeautifulSoup(html_content, 'html.parser') + tables = soup.find_all('table') + + parsed_tables = [] + + for table in tables: + rows = table.find_all(['tr']) + table_data = [] + + for row in rows: + cells = row.find_all(['th', 'td']) + row_data = [cell.get_text().strip() for cell in cells] + table_data.append(row_data) + + # Ensure all rows have the same number of columns + if table_data: + max_cols = max(len(row) for row in table_data) + padded_data = [row + [''] * (max_cols - len(row)) for row in table_data] + table_array = np.array(padded_data) + parsed_tables.append(table_array) + + return parsed_tables @dataclass @@ -401,6 +400,23 @@ class TableTest(BasePDFTest): return False, f"Found cells matching '{self.cell}' but relationships were not satisfied: {'; '.join(failed_reasons)}" +@dataclass +class RepetitionTest(BasePDFTest): + max_repeats: int=10 + + def run(self, content: str) -> Tuple[bool, str]: + # Makes sure that the content has no egregious repeated ngrams at the end, which indicate a degradation of quality + d = RepeatDetector(max_ngram_size=5) + d.add_letters(content) + repeats = d.ngram_repeats() + + for index, count in enumerate(repeats): + if count > self.max_repeats: + return False, f"Text ends with {count} repeating {index+1}-grams, invalid" + + return True, "" + + def load_tests(jsonl_file: str) -> List[BasePDFTest]: """ Load tests from a JSONL file. @@ -412,6 +428,7 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: A list of test objects. """ tests: List[BasePDFTest] = [] + unique_ids = set() with open(jsonl_file, "r") as file: for line_number, line in enumerate(file, start=1): line = line.strip() @@ -430,6 +447,11 @@ def load_tests(jsonl_file: str) -> List[BasePDFTest]: else: raise ValidationError(f"Unknown test type: {test_type}") + if test.id in unique_ids: + raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.") + else: + unique_ids.add(test.id) + tests.append(test) except json.JSONDecodeError as e: print(f"Error parsing JSON on line {line_number}: {e}") diff --git a/olmocr/repeatdetect.py b/olmocr/repeatdetect.py index 76166e4..80bc110 100644 --- a/olmocr/repeatdetect.py +++ b/olmocr/repeatdetect.py @@ -2,7 +2,7 @@ import random import string import time import unittest - +import re class RepeatDetector: def __init__(self, max_ngram_size: int = 10): @@ -18,20 +18,23 @@ class RepeatDetector: if not self.data: return result + # Normalize all whitespace to single spaces + text = re.sub(r'\s+', ' ', self.data) + # For each n-gram size for size in range(1, self.max_ngram_size + 1): - if len(self.data) < size: + if len(text) < size: continue # Get the last n-gram - target = self.data[-size:] + target = text[-size:] # Count backwards from the end to find repeats count = 0 - pos = len(self.data) - size # Start position for previous n-gram + pos = len(text) - size # Start position for previous n-gram while pos >= 0: - if self.data[pos : pos + size] == target: + if text[pos : pos + size] == target: count += 1 pos -= size # Move back by the size of the n-gram else: