Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-12 16:52:20 +00:00)
Adding stricter math and table tests when in synthetic mode
This commit is contained in:
commit a00d9d172e (parent 1197c35808)
@@ -729,6 +729,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, rand
                     "type": TestType.TABLE.value,
                     "cell": cell_text,
                     "max_diffs": 0,
+                    "ignore_markdown_tables": True,
                 }
 
                 # Check cell up
@@ -948,6 +949,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, rand
                     "type": "math",
                     "math": equation,
                     "max_diffs": 0,
+                    "ignore_dollar_delimited": True,
                 }
             )
 
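For illustration, this is roughly the shape of the test records the generator now emits in synthetic mode; the cell text and equation below are hypothetical placeholders, and the "table" string is assumed to be what TestType.TABLE.value resolves to.

table_test = {
    "type": "table",                  # assumed value of TestType.TABLE.value
    "cell": "Total revenue",          # hypothetical cell text
    "max_diffs": 0,
    "ignore_markdown_tables": True,   # new: markdown tables in the output are not checked
}
math_test = {
    "type": "math",
    "math": r"E = mc^2",              # hypothetical equation
    "max_diffs": 0,
    "ignore_dollar_delimited": True,  # new: $...$ and $$...$$ forms are not accepted
}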
@@ -633,6 +633,8 @@ class TableTest(BasePDFTest):
     top_heading: str = ""
     left_heading: str = ""
 
+    ignore_markdown_tables: bool = False
+
     def __post_init__(self):
         super().__post_init__()
         if self.type != TestType.TABLE.value:
@@ -670,6 +672,7 @@ class TableTest(BasePDFTest):
         threshold = max(0.5, threshold)
 
         # Parse tables based on content_type
+        if not self.ignore_markdown_tables:
             md_tables = parse_markdown_tables(content)
             tables_to_check.extend(md_tables)
 
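The effect of the new guard is that markdown tables in the candidate output are only collected when the flag is left at its default. A minimal standalone sketch of the same logic (not the repository's method; the parser callables are passed in here only to keep the sketch self-contained, and the HTML path is an assumption about the surrounding code):

def collect_tables(content, parse_markdown_tables, parse_html_tables, ignore_markdown_tables=False):
    # Mirror of the guard added in TableTest.run above.
    tables_to_check = []
    if not ignore_markdown_tables:
        tables_to_check.extend(parse_markdown_tables(content))
    tables_to_check.extend(parse_html_tables(content))  # assumed: HTML tables are still checked
    return tables_to_check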
@@ -926,6 +929,8 @@ class BaselineTest(BasePDFTest):
 class MathTest(BasePDFTest):
     math: str
 
+    ignore_dollar_delimited: bool = False
+
     def __post_init__(self):
         super().__post_init__()
         if self.type != TestType.MATH.value:
@@ -941,12 +946,16 @@ class MathTest(BasePDFTest):
     def run(self, content: str) -> Tuple[bool, str]:
         # Store both the search pattern and the full pattern to replace
         patterns = [
-            (r"\$\$(.+?)\$\$", r"\$\$(.+?)\$\$"),  # $$...$$
             (r"\\\((.+?)\\\)", r"\\\((.+?)\\\)"),  # \(...\)
             (r"\\\[(.+?)\\\]", r"\\\[(.+?)\\\]"),  # \[...\]
-            (r"\$(.+?)\$", r"\$(.+?)\$"),  # $...$
         ]
 
+        if not self.ignore_dollar_delimited:
+            patterns.extend([
+                (r"\$\$(.+?)\$\$", r"\$\$(.+?)\$\$"),  # $$...$$
+                (r"\$(.+?)\$", r"\$(.+?)\$"),  # $...$
+            ])
+
         equations = []
         modified_content = content
 
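The pattern change can be exercised on its own with the regexes from the diff. This standalone sketch (not the test class itself, only the search side of the pattern pairs) shows that with ignore_dollar_delimited set, only \( ... \) and \[ ... \] delimited math is extracted:

import re

def extract_equations(content, ignore_dollar_delimited=False):
    # Same delimiters as in MathTest.run above, simplified to search patterns only.
    patterns = [r"\\\((.+?)\\\)", r"\\\[(.+?)\\\]"]
    if not ignore_dollar_delimited:
        patterns.extend([r"\$\$(.+?)\$\$", r"\$(.+?)\$"])
    equations = []
    for pattern in patterns:
        equations.extend(re.findall(pattern, content, flags=re.DOTALL))
    return equations

sample = r"Euler: \(e^{i\pi}+1=0\) and inline $a^2+b^2=c^2$."
print(extract_equations(sample))                                # both equations found
print(extract_equations(sample, ignore_dollar_delimited=True))  # only the \( ... \) equation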
@@ -28,6 +28,7 @@ import sqlite3
 import argparse
 from pathlib import Path
 import re
+import os
 
 
 def get_bench_urls(bench_data_dir):
@@ -70,7 +71,125 @@ def local_path_to_short_hash(local_path)
     return None
 
 
-def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path):
+def find_and_handle_contaminated_files(metadata_jsonl_path, contaminated_pdf_ids, delete_mode=False):
+    """Find and optionally delete files related to contaminated PDFs.
+
+    Returns:
+        List of files that were deleted or would be deleted
+    """
+    # Get the base directory from metadata jsonl path
+    metadata_dir = Path(metadata_jsonl_path).parent
+    output_dir = metadata_dir.parent  # Go up one level from metadata directory
+
+    # Get the name from the metadata jsonl filename (e.g., "synthetic" from "synthetic.jsonl")
+    name = Path(metadata_jsonl_path).stem
+
+    files_to_delete = []
+
+    for pdf_id in contaminated_pdf_ids:
+        # Pattern for files related to this pdf_id
+        # Based on mine_html_templates.py, the files are named with pattern:
+        # {pdf_id}_page{page_num}.{extension}
+
+        # Find HTML files
+        html_dir = output_dir / "html" / name
+        if html_dir.exists():
+            for html_file in html_dir.glob(f"{pdf_id}_page*.html"):
+                files_to_delete.append(html_file)
+
+        # Find PDF files (both original and rendered)
+        pdfs_dir = output_dir / "pdfs" / name
+        if pdfs_dir.exists():
+            for pdf_file in pdfs_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_file)
+
+        # Find markdown files in training directory
+        training_dir = output_dir / "training" / name
+        if training_dir.exists():
+            for md_file in training_dir.glob(f"{pdf_id}_page*.md"):
+                files_to_delete.append(md_file)
+            # Also check for PDF symlinks
+            for pdf_link in training_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_link)
+
+        # Find files in bench_data directory
+        bench_data_dir = output_dir / "bench_data"
+
+        # Check synthetic PDFs subdirectory
+        bench_synthetic_dir = bench_data_dir / "pdfs" / name
+        if bench_synthetic_dir.exists():
+            for pdf_file in bench_synthetic_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_file)
+
+        # Check claude_original subdirectory
+        claude_original_dir = bench_data_dir / "claude_original" / name
+        if claude_original_dir.exists():
+            for md_file in claude_original_dir.glob(f"{pdf_id}_page*.md"):
+                files_to_delete.append(md_file)
+
+    # Remove tests from bench_data JSONL file
+    jsonl_file = bench_data_dir / f"{name}.jsonl"
+    if jsonl_file.exists():
+        # Read all tests
+        remaining_tests = []
+        removed_tests = 0
+
+        with open(jsonl_file, 'r') as f:
+            for line in f:
+                try:
+                    test = json.loads(line)
+                    # Check if this test belongs to a contaminated PDF
+                    # Test PDFs are in format "{name}/{pdf_id}_page{page_num}.pdf"
+                    test_pdf = test.get('pdf', '')
+                    is_contaminated = False
+                    for pdf_id in contaminated_pdf_ids:
+                        if f"{pdf_id}_page" in test_pdf:
+                            is_contaminated = True
+                            removed_tests += 1
+                            break
+
+                    if not is_contaminated:
+                        remaining_tests.append(test)
+                except json.JSONDecodeError:
+                    continue
+
+        if removed_tests > 0:
+            if delete_mode:
+                # Rewrite the file without contaminated tests
+                with open(jsonl_file, 'w') as f:
+                    for test in remaining_tests:
+                        f.write(json.dumps(test) + '\n')
+                print(f"Removed {removed_tests} tests from {jsonl_file}")
+            else:
+                print(f"Would remove {removed_tests} tests from {jsonl_file}")
+
+    # Print summary of files to delete
+    if files_to_delete:
+        print(f"\n{'Deleting' if delete_mode else 'Would delete'} {len(files_to_delete)} files:")
+        for file_path in sorted(files_to_delete):  # Show first 10
+            relative_path = file_path.relative_to(output_dir) if output_dir in file_path.parents else file_path
+            print(f"  - {relative_path}")
+
+            # Actually delete if in delete mode
+            if delete_mode:
+                try:
+                    if file_path.is_symlink() or file_path.exists():
+                        file_path.unlink()
+                except Exception as e:
+                    print(f"  Error deleting: {e}")
+
+        if delete_mode:
+            print(f"\nSuccessfully deleted {len(files_to_delete)} files")
+        else:
+            print(f"\nTo actually delete these files, run with --delete flag")
+    else:
+        print("\nNo files found to delete")
+
+    return files_to_delete
+
+
+def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path, delete_mode=False):
     """Main function to check for contamination between bench data and training data."""
     print(f"Checking contamination...")
     print(f"Bench data directory: {bench_data_dir}")
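A quick sketch of how the new helper might be called directly, with hypothetical paths and PDF IDs; by default it only reports what it would do, and files are removed (and the bench JSONL rewritten) only when delete_mode is True:

# Hypothetical paths and PDF IDs, for illustration only.
contaminated = {"0a1b2c3d4e5f6a7b", "f00dfacecafebeef"}

# Dry run: prints the files that would be deleted and the tests that would be dropped.
find_and_handle_contaminated_files("output/metadata/synthetic.jsonl", contaminated, delete_mode=False)

# Destructive run: unlinks the files and rewrites the bench JSONL without the contaminated tests.
find_and_handle_contaminated_files("output/metadata/synthetic.jsonl", contaminated, delete_mode=True)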
@@ -174,25 +293,84 @@ def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path)
     print("Step 4: Checking for contamination...")
     contaminated_urls = bench_urls.intersection(real_urls)
 
+    # Track which PDF IDs are contaminated (including those with blank URLs)
+    contaminated_pdf_ids = set()
+
+    # Add PDF IDs with blank URLs to contaminated set
+    for entry in blank_url_entries:
+        pdf_id = entry.get('pdf_id', 'N/A')
+        if pdf_id != 'N/A':
+            contaminated_pdf_ids.add(pdf_id)
+
     if contaminated_urls:
-        print(f"\n⚠️ CONTAMINATION DETECTED! Found {len(contaminated_urls)} matching URLs:")
-        for url in sorted(contaminated_urls)[:10]:  # Show first 10
+        # Find the pdf_ids that correspond to contaminated URLs
+        for metadata_entry in metadata_entries:
+            source_url = metadata_entry.get('source_url')
+            pdf_id = metadata_entry.get('pdf_id', 'N/A')
+            pdf_hash = None
+
+            # Process URL to get hash
+            if source_url.startswith("s3://"):
+                pdf_hash = s3_url_to_hash(source_url)
+            elif source_url.startswith("./"):
+                short_hash = local_path_to_short_hash(source_url)
+                if short_hash:
+                    conn_temp = sqlite3.connect(sqlite_db_path)
+                    cursor_temp = conn_temp.cursor()
+                    cursor_temp.execute("SELECT full_hash FROM substr_to_full_hash WHERE pdf_hash = ?", (short_hash,))
+                    result = cursor_temp.fetchone()
+                    if result:
+                        pdf_hash = result[0]
+                    conn_temp.close()
+
+            # If we have a hash, look up the real URI
+            if pdf_hash:
+                conn_temp = sqlite3.connect(sqlite_db_path)
+                cursor_temp = conn_temp.cursor()
+                cursor_temp.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
+                result = cursor_temp.fetchone()
+                conn_temp.close()
+
+                if result and result[0] and result[0] in contaminated_urls:
+                    contaminated_pdf_ids.add(pdf_id)
+
+    # Check if we have any contamination (URL matches or blank URLs)
+    total_contaminated = len(contaminated_urls) + len(blank_url_entries)
+
+    if total_contaminated > 0:
+        print(f"\n⚠️ CONTAMINATION DETECTED!")
+        if contaminated_urls:
+            print(f"  - Found {len(contaminated_urls)} matching URLs")
+        if blank_url_entries:
+            print(f"  - Found {len(blank_url_entries)} entries with blank URLs (treated as contaminated)")
+        print(f"  - Total contaminated PDF IDs: {len(contaminated_pdf_ids)}")
+
+        if contaminated_urls:
+            print(f"\nMatching URLs (first 10):")
+            for url in sorted(contaminated_urls)[:10]:
                 print(f"  - {url}")
             if len(contaminated_urls) > 10:
                 print(f"    ... and {len(contaminated_urls) - 10} more")
+
+        # Handle file deletion/dry run
+        if contaminated_pdf_ids:
+            print(f"\nProcessing files for {len(contaminated_pdf_ids)} contaminated PDFs...")
+            find_and_handle_contaminated_files(metadata_jsonl_path, contaminated_pdf_ids, delete_mode)
     else:
-        print("\n✅ No contamination detected. Bench URLs and training URLs are disjoint.")
+        print("\n✅ No contamination detected. Bench URLs and training URLs are disjoint, and no blank URLs found.")
 
     # Print summary statistics
     print(f"\nSummary:")
     print(f"  Bench URLs: {len(bench_urls)}")
     print(f"  Training URLs (mapped): {len(real_urls)}")
     print(f"  Contaminated URLs: {len(contaminated_urls)}")
+    print(f"  Blank URL entries: {len(blank_url_entries)}")
+    print(f"  Total contaminated: {total_contaminated}")
     if bench_urls:
         contamination_rate = (len(contaminated_urls) / len(bench_urls)) * 100
         print(f"  Contamination rate: {contamination_rate:.2f}%")
 
-    return len(contaminated_urls)
+    return total_contaminated
 
 
 def main():
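The URL-to-URI resolution above boils down to a two-step lookup against the SQLite database. A minimal standalone sketch using the same two queries (table and column names exactly as they appear in the diff; the helper itself is illustrative and not part of the script):

import sqlite3

def resolve_real_uri(sqlite_db_path, short_hash):
    # Illustrative helper: short hash -> full hash -> original URI.
    conn = sqlite3.connect(sqlite_db_path)
    try:
        cur = conn.cursor()
        cur.execute("SELECT full_hash FROM substr_to_full_hash WHERE pdf_hash = ?", (short_hash,))
        row = cur.fetchone()
        if row is None:
            return None
        cur.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (row[0],))
        row = cur.fetchone()
        return row[0] if row else None
    finally:
        conn.close()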
@@ -211,6 +389,11 @@ def main():
         "sqlite_db",
         help="Path to SQLite database with pdf_mapping table"
     )
+    parser.add_argument(
+        "--delete",
+        action="store_true",
+        help="Delete contaminated files (default is dry run)"
+    )
 
     args = parser.parse_args()
 
@@ -231,7 +414,8 @@ def main():
     contaminated_count = check_contamination(
         args.bench_data_dir,
         args.metadata_jsonl,
-        args.sqlite_db
+        args.sqlite_db,
+        delete_mode=args.delete
     )
 
     # Return non-zero exit code if contamination found
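Putting the new flag together with the existing positional arguments, an invocation might look like the following; the script filename and paths here are assumptions for illustration:

# Dry run (default): report contamination and list the files that would be removed
python check_contamination.py bench_data/ output/metadata/synthetic.jsonl pdf_mapping.sqlite

# Actually delete the contaminated files and rewrite the bench JSONL
python check_contamination.py bench_data/ output/metadata/synthetic.jsonl pdf_mapping.sqlite --delete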