Adding stricter math and table tests when in synthetic mode

Jake Poznanski 2025-09-23 18:37:50 +00:00
parent 1197c35808
commit a00d9d172e
3 changed files with 212 additions and 17 deletions

View File

@@ -729,6 +729,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, rand
             "type": TestType.TABLE.value,
             "cell": cell_text,
             "max_diffs": 0,
+            "ignore_markdown_tables": True,
         }
     # Check cell up
@@ -948,6 +949,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, rand
             "type": "math",
             "math": equation,
             "max_diffs": 0,
+            "ignore_dollar_delimited": True,
         }
     )
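
Together with max_diffs: 0 (exact matching), these flags are what makes the synthetic-mode tests stricter: the table check presumably keeps only the non-markdown parsing path, and the math check accepts only \(...\) and \[...\] delimiters. A minimal sketch of the emitted entries, limited to the fields visible in the hunks above:

    # Sketch only: any fields beyond those shown in the hunks above are omitted.
    table_test = {
        "type": TestType.TABLE.value,
        "cell": cell_text,
        "max_diffs": 0,                    # exact cell match required
        "ignore_markdown_tables": True,    # markdown tables in the output no longer satisfy the test
    }
    math_test = {
        "type": "math",
        "math": equation,
        "max_diffs": 0,
        "ignore_dollar_delimited": True,   # $...$ and $$...$$ in the output no longer satisfy the test
    }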

View File

@@ -633,6 +633,8 @@ class TableTest(BasePDFTest):
     top_heading: str = ""
     left_heading: str = ""
+    ignore_markdown_tables: bool = False
+
     def __post_init__(self):
         super().__post_init__()
         if self.type != TestType.TABLE.value:
@@ -670,6 +672,7 @@ class TableTest(BasePDFTest):
         threshold = max(0.5, threshold)

         # Parse tables based on content_type
-        md_tables = parse_markdown_tables(content)
-        tables_to_check.extend(md_tables)
+        if not self.ignore_markdown_tables:
+            md_tables = parse_markdown_tables(content)
+            tables_to_check.extend(md_tables)
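
The effect of the new flag on TableTest.run is that markdown tables in the OCR output are no longer collected into tables_to_check. A standalone toy sketch of the gating (the two parse helpers below are simplified stand-ins, not the bench module's real parsers):

    import re

    def parse_markdown_tables(content):
        # Toy stand-in: treat any pipe-delimited line as a one-row "table".
        return [[c.strip() for c in line.strip().strip("|").split("|")]
                for line in content.splitlines() if line.strip().startswith("|")]

    def parse_html_tables(content):
        # Toy stand-in: pull <td> cells out of each <table> block.
        return [re.findall(r"<td>(.*?)</td>", tbl)
                for tbl in re.findall(r"<table>.*?</table>", content, re.S)]

    def collect_tables(content, ignore_markdown_tables=False):
        tables_to_check = []
        if not ignore_markdown_tables:   # same gating as in TableTest.run above
            tables_to_check.extend(parse_markdown_tables(content))
        tables_to_check.extend(parse_html_tables(content))
        return tables_to_check

    doc = "| Total | 42 |\n<table><tr><td>Total</td><td>42</td></tr></table>"
    print(collect_tables(doc))                                # markdown and HTML rows
    print(collect_tables(doc, ignore_markdown_tables=True))   # only the HTML row remains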
@@ -926,6 +929,8 @@ class BaselineTest(BasePDFTest):
 class MathTest(BasePDFTest):
     math: str
+    ignore_dollar_delimited: bool = False
+
     def __post_init__(self):
         super().__post_init__()
         if self.type != TestType.MATH.value:
@@ -941,12 +946,16 @@ class MathTest(BasePDFTest):
     def run(self, content: str) -> Tuple[bool, str]:
         # Store both the search pattern and the full pattern to replace
         patterns = [
-            (r"\$\$(.+?)\$\$", r"\$\$(.+?)\$\$"),  # $$...$$
             (r"\\\((.+?)\\\)", r"\\\((.+?)\\\)"),  # \(...\)
             (r"\\\[(.+?)\\\]", r"\\\[(.+?)\\\]"),  # \[...\]
-            (r"\$(.+?)\$", r"\$(.+?)\$"),  # $...$
         ]
+        if not self.ignore_dollar_delimited:
+            patterns.extend([
+                (r"\$\$(.+?)\$\$", r"\$\$(.+?)\$\$"),  # $$...$$
+                (r"\$(.+?)\$", r"\$(.+?)\$"),  # $...$
+            ])

         equations = []
         modified_content = content
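
In practice the flag restricts which delimiters count as math in the OCR output. A minimal, runnable sketch of the filtering (a toy function, not the real MathTest.run, which also rewrites modified_content):

    import re

    def extract_equations(content, ignore_dollar_delimited=False):
        patterns = [r"\\\((.+?)\\\)", r"\\\[(.+?)\\\]"]        # \(...\) and \[...\]
        if not ignore_dollar_delimited:
            patterns.extend([r"\$\$(.+?)\$\$", r"\$(.+?)\$"])   # $$...$$ and $...$
        found = []
        for pattern in patterns:
            found.extend(re.findall(pattern, content))
        return found

    text = r"Inline \(a^2\) and dollar-style $b^2$."
    print(extract_equations(text))                                # ['a^2', 'b^2']
    print(extract_equations(text, ignore_dollar_delimited=True))  # ['a^2'] only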

View File

@@ -28,6 +28,7 @@ import sqlite3
 import argparse
 from pathlib import Path
 import re
+import os

 def get_bench_urls(bench_data_dir):
@@ -70,7 +71,125 @@ def local_path_to_short_hash(local_path):
     return None

-def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path):
+def find_and_handle_contaminated_files(metadata_jsonl_path, contaminated_pdf_ids, delete_mode=False):
+    """Find and optionally delete files related to contaminated PDFs.
+
+    Returns:
+        List of files that were deleted or would be deleted
+    """
+    # Get the base directory from metadata jsonl path
+    metadata_dir = Path(metadata_jsonl_path).parent
+    output_dir = metadata_dir.parent  # Go up one level from metadata directory
+
+    # Get the name from the metadata jsonl filename (e.g., "synthetic" from "synthetic.jsonl")
+    name = Path(metadata_jsonl_path).stem
+
+    files_to_delete = []
+    for pdf_id in contaminated_pdf_ids:
+        # Pattern for files related to this pdf_id
+        # Based on mine_html_templates.py, the files are named with pattern:
+        # {pdf_id}_page{page_num}.{extension}
+
+        # Find HTML files
+        html_dir = output_dir / "html" / name
+        if html_dir.exists():
+            for html_file in html_dir.glob(f"{pdf_id}_page*.html"):
+                files_to_delete.append(html_file)
+
+        # Find PDF files (both original and rendered)
+        pdfs_dir = output_dir / "pdfs" / name
+        if pdfs_dir.exists():
+            for pdf_file in pdfs_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_file)
+
+        # Find markdown files in training directory
+        training_dir = output_dir / "training" / name
+        if training_dir.exists():
+            for md_file in training_dir.glob(f"{pdf_id}_page*.md"):
+                files_to_delete.append(md_file)
+            # Also check for PDF symlinks
+            for pdf_link in training_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_link)
+
+        # Find files in bench_data directory
+        bench_data_dir = output_dir / "bench_data"
+
+        # Check synthetic PDFs subdirectory
+        bench_synthetic_dir = bench_data_dir / "pdfs" / name
+        if bench_synthetic_dir.exists():
+            for pdf_file in bench_synthetic_dir.glob(f"{pdf_id}_page*.pdf"):
+                files_to_delete.append(pdf_file)
+
+        # Check claude_original subdirectory
+        claude_original_dir = bench_data_dir / "claude_original" / name
+        if claude_original_dir.exists():
+            for md_file in claude_original_dir.glob(f"{pdf_id}_page*.md"):
+                files_to_delete.append(md_file)
+
+    # Remove tests from bench_data JSONL file
+    jsonl_file = bench_data_dir / f"{name}.jsonl"
+    if jsonl_file.exists():
+        # Read all tests
+        remaining_tests = []
+        removed_tests = 0
+        with open(jsonl_file, 'r') as f:
+            for line in f:
+                try:
+                    test = json.loads(line)
+                    # Check if this test belongs to a contaminated PDF
+                    # Test PDFs are in format "{name}/{pdf_id}_page{page_num}.pdf"
+                    test_pdf = test.get('pdf', '')
+                    is_contaminated = False
+                    for pdf_id in contaminated_pdf_ids:
+                        if f"{pdf_id}_page" in test_pdf:
+                            is_contaminated = True
+                            removed_tests += 1
+                            break
+                    if not is_contaminated:
+                        remaining_tests.append(test)
+                except json.JSONDecodeError:
+                    continue
+
+        if removed_tests > 0:
+            if delete_mode:
+                # Rewrite the file without contaminated tests
+                with open(jsonl_file, 'w') as f:
+                    for test in remaining_tests:
+                        f.write(json.dumps(test) + '\n')
+                print(f"Removed {removed_tests} tests from {jsonl_file}")
+            else:
+                print(f"Would remove {removed_tests} tests from {jsonl_file}")
+
+    # Print summary of files to delete
+    if files_to_delete:
+        print(f"\n{'Deleting' if delete_mode else 'Would delete'} {len(files_to_delete)} files:")
+        for file_path in sorted(files_to_delete):
+            relative_path = file_path.relative_to(output_dir) if output_dir in file_path.parents else file_path
+            print(f"  - {relative_path}")
+
+            # Actually delete if in delete mode
+            if delete_mode:
+                try:
+                    if file_path.is_symlink() or file_path.exists():
+                        file_path.unlink()
+                except Exception as e:
+                    print(f"  Error deleting: {e}")
+
+        if delete_mode:
+            print(f"\nSuccessfully deleted {len(files_to_delete)} files")
+        else:
+            print(f"\nTo actually delete these files, run with --delete flag")
+    else:
+        print("\nNo files found to delete")
+
+    return files_to_delete
+
+def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path, delete_mode=False):
     """Main function to check for contamination between bench data and training data."""
     print(f"Checking contamination...")
     print(f"Bench data directory: {bench_data_dir}")
@@ -174,25 +293,84 @@ def check_contamination(bench_data_dir, metadata_jsonl_path, sqlite_db_path):
     print("Step 4: Checking for contamination...")
     contaminated_urls = bench_urls.intersection(real_urls)

+    # Track which PDF IDs are contaminated (including those with blank URLs)
+    contaminated_pdf_ids = set()
+
+    # Add PDF IDs with blank URLs to contaminated set
+    for entry in blank_url_entries:
+        pdf_id = entry.get('pdf_id', 'N/A')
+        if pdf_id != 'N/A':
+            contaminated_pdf_ids.add(pdf_id)
+
     if contaminated_urls:
-        print(f"\n⚠️ CONTAMINATION DETECTED! Found {len(contaminated_urls)} matching URLs:")
-        for url in sorted(contaminated_urls)[:10]:  # Show first 10
+        # Find the pdf_ids that correspond to contaminated URLs
+        for metadata_entry in metadata_entries:
+            source_url = metadata_entry.get('source_url')
+            pdf_id = metadata_entry.get('pdf_id', 'N/A')
+            pdf_hash = None
+
+            # Process URL to get hash
+            if source_url.startswith("s3://"):
+                pdf_hash = s3_url_to_hash(source_url)
+            elif source_url.startswith("./"):
+                short_hash = local_path_to_short_hash(source_url)
+                if short_hash:
+                    conn_temp = sqlite3.connect(sqlite_db_path)
+                    cursor_temp = conn_temp.cursor()
+                    cursor_temp.execute("SELECT full_hash FROM substr_to_full_hash WHERE pdf_hash = ?", (short_hash,))
+                    result = cursor_temp.fetchone()
+                    if result:
+                        pdf_hash = result[0]
+                    conn_temp.close()
+
+            # If we have a hash, look up the real URI
+            if pdf_hash:
+                conn_temp = sqlite3.connect(sqlite_db_path)
+                cursor_temp = conn_temp.cursor()
+                cursor_temp.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
+                result = cursor_temp.fetchone()
+                conn_temp.close()
+
+                if result and result[0] and result[0] in contaminated_urls:
+                    contaminated_pdf_ids.add(pdf_id)
+
+    # Check if we have any contamination (URL matches or blank URLs)
+    total_contaminated = len(contaminated_urls) + len(blank_url_entries)
+
+    if total_contaminated > 0:
+        print(f"\n⚠️ CONTAMINATION DETECTED!")
+        if contaminated_urls:
+            print(f"  - Found {len(contaminated_urls)} matching URLs")
+        if blank_url_entries:
+            print(f"  - Found {len(blank_url_entries)} entries with blank URLs (treated as contaminated)")
+        print(f"  - Total contaminated PDF IDs: {len(contaminated_pdf_ids)}")
+
+        if contaminated_urls:
+            print(f"\nMatching URLs (first 10):")
+            for url in sorted(contaminated_urls)[:10]:
                 print(f"  - {url}")
             if len(contaminated_urls) > 10:
                 print(f"  ... and {len(contaminated_urls) - 10} more")
+
+        # Handle file deletion/dry run
+        if contaminated_pdf_ids:
+            print(f"\nProcessing files for {len(contaminated_pdf_ids)} contaminated PDFs...")
+            find_and_handle_contaminated_files(metadata_jsonl_path, contaminated_pdf_ids, delete_mode)
     else:
-        print("\n✅ No contamination detected. Bench URLs and training URLs are disjoint.")
+        print("\n✅ No contamination detected. Bench URLs and training URLs are disjoint, and no blank URLs found.")

     # Print summary statistics
     print(f"\nSummary:")
     print(f"  Bench URLs: {len(bench_urls)}")
     print(f"  Training URLs (mapped): {len(real_urls)}")
     print(f"  Contaminated URLs: {len(contaminated_urls)}")
+    print(f"  Blank URL entries: {len(blank_url_entries)}")
+    print(f"  Total contaminated: {total_contaminated}")

     if bench_urls:
         contamination_rate = (len(contaminated_urls) / len(bench_urls)) * 100
         print(f"  Contamination rate: {contamination_rate:.2f}%")

-    return len(contaminated_urls)
+    return total_contaminated

 def main():
@@ -211,6 +389,11 @@ def main():
         "sqlite_db",
         help="Path to SQLite database with pdf_mapping table"
     )
+    parser.add_argument(
+        "--delete",
+        action="store_true",
+        help="Delete contaminated files (default is dry run)"
+    )

     args = parser.parse_args()
@@ -231,7 +414,8 @@ def main():
     contaminated_count = check_contamination(
         args.bench_data_dir,
         args.metadata_jsonl,
-        args.sqlite_db
+        args.sqlite_db,
+        delete_mode=args.delete
     )

     # Return non-zero exit code if contamination found
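
The tail of main() is outside this hunk; one way the returned count would typically drive the exit code, assumed from the comment above rather than taken from the diff:

    import sys

    # Assumed: exit non-zero when any contamination was found.
    sys.exit(1 if contaminated_count > 0 else 0)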