Pipeline scales temperature automatically, increases performance ~2%

Jake Poznanski 2025-03-14 22:27:51 -07:00
parent 4768ac4be5
commit 1f8cc59b22
2 changed files with 57 additions and 70 deletions
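In short: the base sampling temperature in build_page_query drops from 0.8 to 0.0, and per-page retries now step the temperature back up one notch per attempt to break out of repetition failures. A minimal sketch of the schedule lookup, reusing the names from the diff below (the temperature_for_attempt helper is mine, not part of the commit):

TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

def temperature_for_attempt(attempt: int) -> float:
    # First attempt is deterministic (0.0); each retry steps up, capped at 0.8
    return TEMPERATURE_BY_ATTEMPT[min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)]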

View File

@@ -18,29 +18,27 @@ import base64
 import json
 import os
 import random
-import time
-from pathlib import Path
-from typing import Dict, List, Optional
+from typing import List, Optional

 import boto3
 import pypdf
 from google import genai
 from google.genai import types
 from tqdm import tqdm

 from olmocr.bench.tests import TextPresenceTest, save_tests
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.filter import PdfFilter


 def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
     """
     Download a PDF file from S3.

     Args:
         s3_path: The S3 path (s3://bucket/path/to/file.pdf)
         local_path: The local path to save the file

     Returns:
         bool: True if download was successful, False otherwise
     """
@@ -49,13 +47,13 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
     parts = s3_path.replace("s3://", "").split("/", 1)
     bucket = parts[0]
     key = parts[1]

     # Create S3 client
     s3 = boto3.client("s3")

     # Create directory if it doesn't exist
     os.makedirs(os.path.dirname(local_path), exist_ok=True)

     # Download file
     s3.download_file(bucket, key, local_path)
     return True
@@ -67,35 +65,35 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
 def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
     """
     Extract a specific page from a PDF and save it as a new PDF.

     Args:
         input_path: Path to the input PDF
         output_path: Path to save the extracted page
         page_num: The page number to extract (0-indexed)

     Returns:
         bool: True if extraction was successful, False otherwise
     """
     try:
         # Ensure output directory exists
         os.makedirs(os.path.dirname(output_path), exist_ok=True)

         # Read the input PDF
         reader = pypdf.PdfReader(input_path)

         # Check if page number is valid
         if page_num >= len(reader.pages):
             print(f"Page number {page_num} out of range for {input_path} with {len(reader.pages)} pages")
             return False

         # Create a new PDF with just the selected page
         writer = pypdf.PdfWriter()
         writer.add_page(reader.pages[page_num])

         # Write the output PDF
         with open(output_path, "wb") as output_file:
             writer.write(output_file)

         return True
     except Exception as e:
         print(f"Error extracting page {page_num} from {input_path}: {str(e)}")
@@ -105,12 +103,12 @@ def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
 def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
     """
     Use Gemini to detect headers and footers in a rendered PDF page.

     Args:
         pdf_path: Path to the PDF file
         page_num: The page number to analyze (0-indexed)
         api_key: Gemini API key

     Returns:
         Optional[List[str]]: List of detected header/footer texts, or None if detection failed
     """
@@ -121,23 +119,21 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:

     # Render the PDF page as an image
     try:
-        image_base64 = render_pdf_to_base64png(
-            pdf_path,
-            page_num=page_num + 1,  # render_pdf_to_base64png is 1-indexed
-            target_longest_image_dim=2048
-        )
+        image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048)  # render_pdf_to_base64png is 1-indexed
     except Exception as e:
         print(f"Error rendering PDF page: {str(e)}")
         return None

     image_part = types.Part(inline_data=types.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))

     contents = [
         types.Content(
             role="user",
             parts=[
                 image_part,
-                types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped if someone were reading it out loud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
+                types.Part.from_text(
+                    text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped if someone were reading it out loud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
+                ),
             ],
         ),
     ]
@@ -149,47 +145,38 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
             max_output_tokens=8192,
             response_mime_type="application/json",
             response_schema=genai.types.Schema(
-                type = genai.types.Type.OBJECT,
-                properties = {
+                type=genai.types.Type.OBJECT,
+                properties={
                     "headers": genai.types.Schema(
-                        type = genai.types.Type.ARRAY,
-                        items = genai.types.Schema(
-                            type = genai.types.Type.STRING,
+                        type=genai.types.Type.ARRAY,
+                        items=genai.types.Schema(
+                            type=genai.types.Type.STRING,
                         ),
                     ),
                     "footers": genai.types.Schema(
-                        type = genai.types.Type.ARRAY,
-                        items = genai.types.Schema(
-                            type = genai.types.Type.STRING,
+                        type=genai.types.Type.ARRAY,
+                        items=genai.types.Schema(
+                            type=genai.types.Type.STRING,
                         ),
                     ),
                 },
             ),
         )

-        response = client.models.generate_content(model=model,
-                                                  contents=contents,
-                                                  config=generate_content_config)
+        response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)

         assert len(response.candidates) > 0, "No candidates found"
-        assert (
-            response.candidates[0].finish_reason == types.FinishReason.STOP
-        ), "Finish reason was not STOP, likely a processing error or repetition failure"
+        assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"

         data = json.loads(response.candidates[0].content.parts[0].text)
         return data.get("headers", []) + data.get("footers", [])


-def process_pdf(
-    s3_path: str,
-    temp_dir: str,
-    output_dir: str,
-    api_key: str,
-    tests: List[TextPresenceTest]
-) -> None:
+def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
     """
     Process a single PDF from S3.

     Args:
         s3_path: S3 path to the PDF
         temp_dir: Directory for temporary files
@@ -200,7 +187,7 @@ def process_pdf(
     # Extract filename from S3 path
     pdf_filename = os.path.basename(s3_path)
     local_pdf_path = os.path.join(temp_dir, pdf_filename)

     # Download PDF from S3
     if not download_pdf_from_s3(s3_path, local_pdf_path):
         return
@@ -210,16 +197,16 @@ def process_pdf(

     if pdf_filter.filter_out_pdf(local_pdf_path):
         print("Filtering out", pdf_filename)
         return

     try:
         # Read the PDF to get the number of pages
         reader = pypdf.PdfReader(local_pdf_path)
         num_pages = len(reader.pages)

         if num_pages == 0:
             print(f"PDF {pdf_filename} has no pages")
             return

         all_pages = list(range(len(reader.pages)))
         random.shuffle(all_pages)
@@ -229,20 +216,19 @@ def process_pdf(

             # Only stick with headers and footers that have some actual data in them
             header_footer_text = [x for x in header_footer_text if len(x.strip()) > 3]

             if not header_footer_text:
                 print(f"No headers/footers detected in {pdf_filename} page {page_num}")
                 continue

             # Extract the page and save to output dir
             pdf_basename = os.path.splitext(pdf_filename)[0]
             output_pdf_path = os.path.join(output_dir, "pdfs", f"{pdf_basename}_pg{page_num+1}.pdf")
             extract_page_from_pdf(local_pdf_path, output_pdf_path, page_num)

             # TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text

             # Create tests for each header/footer text
             for i, text in enumerate(header_footer_text):
                 test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@@ -255,10 +241,10 @@ def process_pdf(
                     max_diffs=0,
                 )
                 tests.append(test)

             print(f"Processed {pdf_filename} page {page_num+1}, found {len(header_footer_text)} headers/footers")
             return

     except Exception as e:
         print(f"Error processing {pdf_filename}: {str(e)}")
     finally:
@@ -274,39 +260,37 @@ def main():
     parser.add_argument("--api_key", help="Gemini API key (if not provided, will use GEMINI_API_KEY environment variable)")
     parser.add_argument("--temp_dir", default="/tmp/mine_headers_footers", help="Directory for temporary files")
     args = parser.parse_args()

     # Get API key
     api_key = args.api_key or os.environ.get("GEMINI_API_KEY")
     if not api_key:
         print("Error: Gemini API key not provided. Use --api_key or set GEMINI_API_KEY environment variable.")
         return

     # Create directories
     os.makedirs(args.temp_dir, exist_ok=True)
     os.makedirs(os.path.join(args.output_dir, "pdfs"), exist_ok=True)

     # Read input list
     with open(args.input_list, "r") as f:
         s3_paths = [line.strip() for line in f if line.strip()]

     print(f"Found {len(s3_paths)} PDF paths in input list")

     # Process each PDF
     tests = []
     for s3_path in tqdm(s3_paths, desc="Processing PDFs"):
         process_pdf(s3_path, args.temp_dir, args.output_dir, api_key, tests)

         # Save tests after each PDF to avoid losing data in case of crashes
         if tests:
             save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))

         if len(tests) > 100:
             break

     print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")


 if __name__ == "__main__":
     main()
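For context on the detect_headers_footers call above: the google-genai structured-output pattern it relies on can be exercised standalone. A minimal sketch, assuming a genai.Client built from the API key and a placeholder model name (the script's actual model string is not shown in these hunks):

import json

from google import genai
from google.genai import types

client = genai.Client(api_key="YOUR_API_KEY")  # assumption: client constructed from the API key
config = types.GenerateContentConfig(
    response_mime_type="application/json",
    response_schema=types.Schema(
        type=types.Type.OBJECT,
        properties={
            "headers": types.Schema(type=types.Type.ARRAY, items=types.Schema(type=types.Type.STRING)),
            "footers": types.Schema(type=types.Type.ARRAY, items=types.Schema(type=types.Type.STRING)),
        },
    ),
)
response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model name, not taken from the diff
    contents="List any header and footer text on this page.",
    config=config,
)
data = json.loads(response.candidates[0].content.parts[0].text)
print(data.get("headers", []) + data.get("footers", []))

Because response_schema pins the output to a JSON object holding two string arrays, the json.loads step is safe whenever the finish reason is STOP, which is exactly what the assert in detect_headers_footers checks.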

View File

@@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim
             }
         ],
         "max_tokens": MAX_TOKENS,
-        "temperature": 0.8,
+        "temperature": 0.0,
     }
@@ -213,7 +213,7 @@ async def apost(url, json_data):
 async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
+    TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

     exponential_backoffs = 0
     local_anchor_text_len = args.target_anchor_text_len
     local_image_rotation = 0
@@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     while attempt < MAX_RETRIES:
         query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
+        query["temperature"] = TEMPERATURE_BY_ATTEMPT[
+            min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
+        ]  # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
         logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
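
Since attempt can exceed the schedule length when MAX_RETRIES is large, the min() clamp keeps every later retry at the final 0.8 entry rather than raising an IndexError. A quick sketch of that behavior (not part of the commit):

TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

for attempt in (0, 3, 8, 12):
    # Indexes past the end of the schedule are clamped to the last entry
    t = TEMPERATURE_BY_ATTEMPT[min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)]
    print(f"attempt {attempt}: temperature {t}")  # 0 -> 0.0, 3 -> 0.3, 8 -> 0.8, 12 -> 0.8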