From 1f8cc59b226435ab64f97958d6baaf84bbefad1b Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Fri, 14 Mar 2025 22:27:51 -0700
Subject: [PATCH] Pipeline scales temperature automatically, increases
 performance ~2%

---
 olmocr/bench/miners/mine_headers_footers.py | 120 +++++++++-----------
 olmocr/pipeline.py                          |   7 +-
 2 files changed, 57 insertions(+), 70 deletions(-)

diff --git a/olmocr/bench/miners/mine_headers_footers.py b/olmocr/bench/miners/mine_headers_footers.py
index 3f2e108..3c4ec09 100644
--- a/olmocr/bench/miners/mine_headers_footers.py
+++ b/olmocr/bench/miners/mine_headers_footers.py
@@ -18,29 +18,27 @@
 import base64
 import json
 import os
 import random
-import time
-from pathlib import Path
-from typing import Dict, List, Optional
+from typing import List, Optional
 
 import boto3
 import pypdf
 from google import genai
 from google.genai import types
-
 from tqdm import tqdm
 
 from olmocr.bench.tests import TextPresenceTest, save_tests
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.filter import PdfFilter
 
+
 def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
     """
     Download a PDF file from S3.
-
+
     Args:
         s3_path: The S3 path (s3://bucket/path/to/file.pdf)
         local_path: The local path to save the file
-
+
     Returns:
         bool: True if download was successful, False otherwise
     """
@@ -49,13 +47,13 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
         parts = s3_path.replace("s3://", "").split("/", 1)
         bucket = parts[0]
         key = parts[1]
-
+
         # Create S3 client
         s3 = boto3.client("s3")
-
+
         # Create directory if it doesn't exist
         os.makedirs(os.path.dirname(local_path), exist_ok=True)
-
+
         # Download file
         s3.download_file(bucket, key, local_path)
         return True
@@ -67,35 +65,35 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
 def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
     """
     Extract a specific page from a PDF and save it as a new PDF.
-
+
     Args:
         input_path: Path to the input PDF
         output_path: Path to save the extracted page
         page_num: The page number to extract (0-indexed)
-
+
     Returns:
         bool: True if extraction was successful, False otherwise
     """
     try:
         # Ensure output directory exists
         os.makedirs(os.path.dirname(output_path), exist_ok=True)
-
+
         # Read the input PDF
         reader = pypdf.PdfReader(input_path)
-
+
         # Check if page number is valid
         if page_num >= len(reader.pages):
             print(f"Page number {page_num} out of range for {input_path} with {len(reader.pages)} pages")
             return False
-
+
         # Create a new PDF with just the selected page
         writer = pypdf.PdfWriter()
         writer.add_page(reader.pages[page_num])
-
+
         # Write the output PDF
         with open(output_path, "wb") as output_file:
             writer.write(output_file)
-
+
         return True
     except Exception as e:
         print(f"Error extracting page {page_num} from {input_path}: {str(e)}")
@@ -105,12 +103,12 @@ def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> b
 def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
     """
     Use Gemini to detect headers and footers in a rendered PDF page.
-
+
     Args:
         pdf_path: Path to the PDF file
         page_num: The page number to analyze (0-indexed)
         api_key: Gemini API key
-
+
     Returns:
         Optional[List[str]]: List of detected header/footer texts, or None if detection failed
     """
@@ -121,23 +119,21 @@
     # Render the PDF page as an image
     try:
-        image_base64 = render_pdf_to_base64png(
-            pdf_path,
-            page_num=page_num + 1,  # render_pdf_to_base64png is 1-indexed
-            target_longest_image_dim=2048
-        )
+        image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048)  # render_pdf_to_base64png is 1-indexed
     except Exception as e:
         print(f"Error rendering PDF page: {str(e)}")
         return None
-
+
     image_part = types.Part(inline_data=types.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))
-
+
     contents = [
         types.Content(
             role="user",
             parts=[
                 image_part,
-                types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
+                types.Part.from_text(
+                    text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped if someone were reading it out loud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
+                ),
             ],
         ),
     ]
@@ -149,47 +145,38 @@
         max_output_tokens=8192,
         response_mime_type="application/json",
         response_schema=genai.types.Schema(
-            type = genai.types.Type.OBJECT,
-            properties = {
+            type=genai.types.Type.OBJECT,
+            properties={
                 "headers": genai.types.Schema(
-                    type = genai.types.Type.ARRAY,
-                    items = genai.types.Schema(
-                        type = genai.types.Type.STRING,
+                    type=genai.types.Type.ARRAY,
+                    items=genai.types.Schema(
+                        type=genai.types.Type.STRING,
                     ),
                 ),
                 "footers": genai.types.Schema(
-                    type = genai.types.Type.ARRAY,
-                    items = genai.types.Schema(
-                        type = genai.types.Type.STRING,
+                    type=genai.types.Type.ARRAY,
+                    items=genai.types.Schema(
+                        type=genai.types.Type.STRING,
                     ),
                 ),
             },
         ),
     )
 
-    response = client.models.generate_content(model=model,
-                                              contents=contents,
-                                              config=generate_content_config)
-
+    response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)
+
     assert len(response.candidates) > 0, "No candidates found"
-    assert (
-        response.candidates[0].finish_reason == types.FinishReason.STOP
-    ), "Finish reason was not STOP, likely a processing error or repetition failure"
+    assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
 
     data = json.loads(response.candidates[0].content.parts[0].text)
     return data.get("headers", []) + data.get("footers", [])
 
-def process_pdf(
-    s3_path: str,
-    temp_dir: str,
-    output_dir: str,
-    api_key: str,
-    tests: List[TextPresenceTest]
-) -> None:
+
+def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
     """
     Process a single PDF from S3.
-
+
     Args:
         s3_path: S3 path to the PDF
         temp_dir: Directory for temporary files
@@ -200,7 +187,7 @@
     # Extract filename from S3 path
     pdf_filename = os.path.basename(s3_path)
     local_pdf_path = os.path.join(temp_dir, pdf_filename)
-
+
     # Download PDF from S3
     if not download_pdf_from_s3(s3_path, local_pdf_path):
         return
@@ -210,16 +197,16 @@
     if pdf_filter.filter_out_pdf(local_pdf_path):
         print("Filtering out", pdf_filename)
         return
-
+
     try:
         # Read the PDF to get the number of pages
         reader = pypdf.PdfReader(local_pdf_path)
         num_pages = len(reader.pages)
-
+
         if num_pages == 0:
             print(f"PDF {pdf_filename} has no pages")
             return
-
+
         all_pages = list(range(len(reader.pages)))
         random.shuffle(all_pages)
 
@@ -229,20 +216,19 @@
             # Only stick with headers and footers that have some actual data in them
             header_footer_text = [x for x in header_footer_text if len(x.strip()) > 3]
-
+
             if not header_footer_text:
                 print(f"No headers/footers detected in {pdf_filename} page {page_num}")
                 continue
-
+
             # Extract the page and save to output dir
             pdf_basename = os.path.splitext(pdf_filename)[0]
            output_pdf_path = os.path.join(output_dir, "pdfs", f"{pdf_basename}_pg{page_num+1}.pdf")
-
+
            extract_page_from_pdf(local_pdf_path, output_pdf_path, page_num)
 
            # TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
-
            # Create tests for each header/footer text
            for i, text in enumerate(header_footer_text):
                test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@@ -255,10 +241,10 @@
                    max_diffs=0,
                )
                tests.append(test)
-
+
            print(f"Processed {pdf_filename} page {page_num+1}, found {len(header_footer_text)} headers/footers")
            return
-
+
    except Exception as e:
        print(f"Error processing {pdf_filename}: {str(e)}")
    finally:
@@ -274,39 +260,37 @@
    parser.add_argument("--api_key", help="Gemini API key (if not provided, will use GEMINI_API_KEY environment variable)")
    parser.add_argument("--temp_dir", default="/tmp/mine_headers_footers", help="Directory for temporary files")
    args = parser.parse_args()
-
+
    # Get API key
    api_key = args.api_key or os.environ.get("GEMINI_API_KEY")
    if not api_key:
        print("Error: Gemini API key not provided. Use --api_key or set GEMINI_API_KEY environment variable.")
        return
-
+
    # Create directories
    os.makedirs(args.temp_dir, exist_ok=True)
    os.makedirs(os.path.join(args.output_dir, "pdfs"), exist_ok=True)
-
+
    # Read input list
    with open(args.input_list, "r") as f:
        s3_paths = [line.strip() for line in f if line.strip()]
-
+
    print(f"Found {len(s3_paths)} PDF paths in input list")
-
+
    # Process each PDF
    tests = []
    for s3_path in tqdm(s3_paths, desc="Processing PDFs"):
        process_pdf(s3_path, args.temp_dir, args.output_dir, api_key, tests)
-
+
        # Save tests after each PDF to avoid losing data in case of crashes
        if tests:
            save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
-            if len(tests) > 100:
-                break
-
    print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 7d98b29..fd75bec 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
                }
            ],
            "max_tokens": MAX_TOKENS,
-            "temperature": 0.8,
+            "temperature": 0.0,
        }
 
 
@@ -213,7 +213,7 @@ async def apost(url, json_data):
 async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
-
+    TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
     exponential_backoffs = 0
     local_anchor_text_len = args.target_anchor_text_len
     local_image_rotation = 0
@@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
     while attempt < MAX_RETRIES:
         query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
+        query["temperature"] = TEMPERATURE_BY_ATTEMPT[
+            min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
+        ]  # Raise the temperature as the number of attempts increases, to overcome repetition issues at the expense of quality
         logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
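
Note on the olmocr/pipeline.py change: the first attempt at a page is now decoded greedily (temperature 0.0), and the sampling temperature is stepped up on each retry so that pages that fall into repetition loops can still converge. The sketch below distills that retry pattern. It is illustrative only, not the pipeline's actual code path: build_query and run_query are hypothetical stand-ins for build_page_query and the SGLang completion request.

    # Minimal sketch of the escalating-temperature retry pattern.
    # build_query/run_query are hypothetical stand-ins, not olmocr APIs.
    TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

    def run_with_escalating_temperature(build_query, run_query, max_retries=8):
        for attempt in range(max_retries):
            query = build_query()
            # Greedy decoding first; hotter sampling on later attempts trades
            # some output quality for a chance to break repetition loops.
            query["temperature"] = TEMPERATURE_BY_ATTEMPT[min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)]
            result = run_query(query)
            if result is not None:
                return result
        return None

Clamping the index with min() keeps any retries beyond the table length at 0.8, which matches the fixed default this patch removes from build_page_query.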