mirror of
https://github.com/allenai/olmocr.git
synced 2026-01-08 13:22:25 +00:00
More fixes to data gen script
This commit is contained in:
parent
becbfdc62d
commit
c32dced59c
@ -22,6 +22,7 @@ from olmocr.data.renderpdf import (
|
||||
get_png_dimensions_from_base64,
|
||||
render_pdf_to_base64png,
|
||||
)
|
||||
from olmocr.filter.filter import PdfFilter, Language
|
||||
|
||||
|
||||
def download_s3_pdf(s3_path, local_path):
|
||||
@ -527,7 +528,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
tests.append(
|
||||
{
|
||||
"pdf": pdf_filename,
|
||||
"page": page_num,
|
||||
"page": 1,
|
||||
"id": f"{pdf_id}_{element_type}_{uuid.uuid4().hex[:8]}",
|
||||
"type": TestType.ABSENT.value,
|
||||
"text": text,
|
||||
@ -554,7 +555,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
tests.append(
|
||||
{
|
||||
"pdf": pdf_filename,
|
||||
"page": page_num,
|
||||
"page": 1,
|
||||
"id": f"{pdf_id}_page_number_{uuid.uuid4().hex[:8]}",
|
||||
"type": TestType.ABSENT.value,
|
||||
"text": page_number_text,
|
||||
@ -608,7 +609,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
# Create a TableTest with relevant relationships
|
||||
test_data = {
|
||||
"pdf": pdf_filename,
|
||||
"page": page_num,
|
||||
"page": 1,
|
||||
"id": f"{pdf_id}_table{table_idx}_{uuid.uuid4().hex[:8]}",
|
||||
"type": TestType.TABLE.value,
|
||||
"cell": cell_text,
|
||||
@ -753,7 +754,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
tests.append(
|
||||
{
|
||||
"pdf": pdf_filename,
|
||||
"page": page_num,
|
||||
"page": 1,
|
||||
"id": f"{pdf_id}_order_{uuid.uuid4().hex[:8]}",
|
||||
"type": TestType.ORDER.value,
|
||||
"before": first_sentence,
|
||||
@ -838,7 +839,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
return unique_tests
|
||||
|
||||
|
||||
def process_pdf(pdf_info, args, client):
|
||||
def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
"""Process a single PDF, render a random page, and create an HTML template."""
|
||||
s3_path, index = pdf_info
|
||||
|
||||
@ -855,6 +856,11 @@ def process_pdf(pdf_info, args, client):
|
||||
if not download_s3_pdf(s3_path, local_pdf_path):
|
||||
print(f"Failed to download PDF from {s3_path}")
|
||||
return None
|
||||
|
||||
# Apply filter if enabled
|
||||
if pdf_filter and pdf_filter.filter_out_pdf(local_pdf_path):
|
||||
print(f"PDF filtered out: {s3_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get page count using pypdf
|
||||
@ -880,10 +886,16 @@ def process_pdf(pdf_info, args, client):
|
||||
# Create output directories
|
||||
html_dir = os.path.join(args.output_dir, "html")
|
||||
pdfs_dir = os.path.join(args.output_dir, "pdfs")
|
||||
markdown_dir = os.path.join(args.output_dir, "markdown")
|
||||
training_dir = os.path.join(args.output_dir, "training")
|
||||
bench_data_dir = os.path.join(args.output_dir, "bench_data")
|
||||
bench_synthetic_dir = os.path.join(bench_data_dir, "pdfs", "synthetic")
|
||||
claude_original_dir = os.path.join(bench_data_dir, "claude_original", "synthetic")
|
||||
os.makedirs(html_dir, exist_ok=True)
|
||||
os.makedirs(pdfs_dir, exist_ok=True)
|
||||
os.makedirs(markdown_dir, exist_ok=True)
|
||||
os.makedirs(training_dir, exist_ok=True)
|
||||
os.makedirs(bench_data_dir, exist_ok=True)
|
||||
os.makedirs(bench_synthetic_dir, exist_ok=True)
|
||||
os.makedirs(claude_original_dir, exist_ok=True)
|
||||
|
||||
# Save HTML to output directory
|
||||
html_path = os.path.join(html_dir, f"{pdf_id}_page{page_num}.html")
|
||||
@ -892,9 +904,28 @@ def process_pdf(pdf_info, args, client):
|
||||
|
||||
# Convert HTML to markdown and save
|
||||
markdown_content = html_to_markdown(html_content)
|
||||
markdown_path = os.path.join(markdown_dir, f"{pdf_id}_page{page_num}.md")
|
||||
markdown_filename = f"{pdf_id}_page{page_num}.md"
|
||||
markdown_path = os.path.join(training_dir, markdown_filename)
|
||||
with open(markdown_path, "w") as f:
|
||||
f.write(markdown_content)
|
||||
|
||||
# Create soft link to PDF in training directory
|
||||
pdf_link_name = f"{pdf_id}_page{page_num}.pdf"
|
||||
pdf_link_path = os.path.join(training_dir, pdf_link_name)
|
||||
# Remove existing link if it exists
|
||||
if os.path.exists(pdf_link_path) or os.path.islink(pdf_link_path):
|
||||
os.remove(pdf_link_path)
|
||||
# Create relative symlink from training to pdfs directory
|
||||
os.symlink(os.path.relpath(os.path.join(pdfs_dir, f"{pdf_id}_page{page_num}.pdf"), training_dir), pdf_link_path)
|
||||
|
||||
# Create soft link to markdown in claude_original/synthetic with new naming scheme
|
||||
claude_md_link_name = f"{pdf_id}_page{page_num}_pg1_repeat1.md"
|
||||
claude_md_link_path = os.path.join(claude_original_dir, claude_md_link_name)
|
||||
# Remove existing link if it exists
|
||||
if os.path.exists(claude_md_link_path) or os.path.islink(claude_md_link_path):
|
||||
os.remove(claude_md_link_path)
|
||||
# Create relative symlink from claude_original/synthetic to training directory
|
||||
os.symlink(os.path.relpath(markdown_path, claude_original_dir), claude_md_link_path)
|
||||
|
||||
# Extract the page and save as PDF
|
||||
original_pdf_path = os.path.join(pdfs_dir, f"{pdf_id}_page{page_num}_original.pdf")
|
||||
@ -929,14 +960,23 @@ def process_pdf(pdf_info, args, client):
|
||||
# If playwright rendering failed and was required, return None to skip this test
|
||||
if not args.skip_playwright and not render_success:
|
||||
return None
|
||||
|
||||
# Create soft link in bench_data/synthetic/ directory
|
||||
if playwright_pdf_path:
|
||||
synthetic_link_path = os.path.join(bench_synthetic_dir, playwright_pdf_filename)
|
||||
# Remove existing link if it exists
|
||||
if os.path.exists(synthetic_link_path) or os.path.islink(synthetic_link_path):
|
||||
os.remove(synthetic_link_path)
|
||||
# Create relative symlink from bench_data/synthetic to pdfs directory
|
||||
os.symlink(os.path.relpath(playwright_pdf_path, bench_synthetic_dir), synthetic_link_path)
|
||||
|
||||
# Generate tests from the HTML content
|
||||
# Use the playwright rendered PDF path for tests
|
||||
tests = generate_tests_from_html(html_content, pdf_id, page_num, verbose_table_testing)
|
||||
|
||||
# Update the PDF path in all tests to use the playwright rendered PDF
|
||||
# Update the PDF path in all tests to use the playwright rendered PDF with synthetic/ prefix
|
||||
for test in tests:
|
||||
test["pdf"] = playwright_pdf_filename
|
||||
test["pdf"] = f"synthetic/{playwright_pdf_filename}"
|
||||
|
||||
# Log table test stats if verbose
|
||||
if verbose_table_testing:
|
||||
@ -973,6 +1013,7 @@ def main():
|
||||
parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
|
||||
parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
|
||||
parser.add_argument("--filter", action="store_true", help="Apply PDF filtering to remove forms, spam, and non-English content")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Ensure output and temp directories exist
|
||||
@ -987,6 +1028,16 @@ def main():
|
||||
|
||||
# Initialize Claude client
|
||||
client = Anthropic(api_key=api_key)
|
||||
|
||||
# Initialize PDF filter if enabled
|
||||
pdf_filter = None
|
||||
if args.filter:
|
||||
pdf_filter = PdfFilter(
|
||||
languages_to_keep={Language.ENGLISH, None}, # None means could not detect language, that's okay keep it, might be an OCR
|
||||
apply_download_spam_check=True,
|
||||
apply_form_check=True,
|
||||
)
|
||||
print("PDF filtering enabled")
|
||||
|
||||
# Reservoir sampling implementation
|
||||
s3_paths = []
|
||||
@ -1010,8 +1061,10 @@ def main():
|
||||
random.shuffle(s3_paths)
|
||||
s3_paths = s3_paths[: args.max_tests]
|
||||
|
||||
# Initialize synthetic.json as a JSONL file (empty initially)
|
||||
synthetic_json_path = os.path.join(args.output_dir, "synthetic.jsonl")
|
||||
# Initialize synthetic.jsonl in bench_data folder as a JSONL file (empty initially)
|
||||
bench_data_dir = os.path.join(args.output_dir, "bench_data")
|
||||
os.makedirs(bench_data_dir, exist_ok=True)
|
||||
synthetic_json_path = os.path.join(bench_data_dir, "synthetic.jsonl")
|
||||
open(synthetic_json_path, "w").close() # Create empty file
|
||||
|
||||
# Counter for test statistics
|
||||
@ -1027,7 +1080,7 @@ def main():
|
||||
# Process PDFs in parallel
|
||||
with ThreadPoolExecutor(max_workers=args.parallel) as executor:
|
||||
# Submit all tasks
|
||||
futures = {executor.submit(process_pdf, (s3_path, i), args, client): s3_path for i, s3_path in enumerate(s3_paths)}
|
||||
futures = {executor.submit(process_pdf, (s3_path, i), args, client, pdf_filter): s3_path for i, s3_path in enumerate(s3_paths)}
|
||||
|
||||
# Process results as they complete
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing PDFs"):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user