Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-11 08:12:22 +00:00
Pipeline scales temperature automatically, increases performance ~2%
This commit is contained in:
parent
4768ac4be5
commit
1f8cc59b22
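In brief, and as a minimal sketch only (the helper and toy loop below are illustrative, not the pipeline's actual API): the first attempt at a page now samples at temperature 0.0, and each retry steps the temperature up a fixed ladder capped at 0.8, trading a little quality for a way out of repetition loops.

# Sketch of the per-attempt temperature escalation (assumed, simplified names).
TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]


def temperature_for_attempt(attempt: int) -> float:
    # Clamp to the last entry once the retry count runs past the ladder.
    return TEMPERATURE_BY_ATTEMPT[min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)]


for attempt in range(10):
    query = {"max_tokens": 4096}  # stand-in for the real page query dict
    query["temperature"] = temperature_for_attempt(attempt)
    # attempt 0 -> 0.0, attempt 3 -> 0.3, attempt 8 and beyond -> 0.8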
@@ -18,29 +18,27 @@ import base64
 import json
 import os
 import random
-import time
-from pathlib import Path
-from typing import Dict, List, Optional
+from typing import List, Optional

 import boto3
 import pypdf
 from google import genai
 from google.genai import types

 from tqdm import tqdm

 from olmocr.bench.tests import TextPresenceTest, save_tests
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.filter import PdfFilter


 def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
     """
     Download a PDF file from S3.

     Args:
         s3_path: The S3 path (s3://bucket/path/to/file.pdf)
         local_path: The local path to save the file

     Returns:
         bool: True if download was successful, False otherwise
     """
@@ -49,13 +47,13 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
         parts = s3_path.replace("s3://", "").split("/", 1)
         bucket = parts[0]
         key = parts[1]

         # Create S3 client
         s3 = boto3.client("s3")

         # Create directory if it doesn't exist
         os.makedirs(os.path.dirname(local_path), exist_ok=True)

         # Download file
         s3.download_file(bucket, key, local_path)
         return True
@@ -67,35 +65,35 @@ def download_pdf_from_s3(s3_path: str, local_path: str) -> bool:
 def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
     """
     Extract a specific page from a PDF and save it as a new PDF.

     Args:
         input_path: Path to the input PDF
         output_path: Path to save the extracted page
         page_num: The page number to extract (0-indexed)

     Returns:
         bool: True if extraction was successful, False otherwise
     """
     try:
         # Ensure output directory exists
         os.makedirs(os.path.dirname(output_path), exist_ok=True)

         # Read the input PDF
         reader = pypdf.PdfReader(input_path)

         # Check if page number is valid
         if page_num >= len(reader.pages):
             print(f"Page number {page_num} out of range for {input_path} with {len(reader.pages)} pages")
             return False

         # Create a new PDF with just the selected page
         writer = pypdf.PdfWriter()
         writer.add_page(reader.pages[page_num])

         # Write the output PDF
         with open(output_path, "wb") as output_file:
             writer.write(output_file)

         return True
     except Exception as e:
         print(f"Error extracting page {page_num} from {input_path}: {str(e)}")
@@ -105,12 +103,12 @@ def extract_page_from_pdf(input_path: str, output_path: str, page_num: int) -> bool:
 def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
     """
     Use Gemini to detect headers and footers in a rendered PDF page.

     Args:
         pdf_path: Path to the PDF file
         page_num: The page number to analyze (0-indexed)
         api_key: Gemini API key

     Returns:
         Optional[List[str]]: List of detected header/footer texts, or None if detection failed
     """
@@ -121,23 +119,21 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:

     # Render the PDF page as an image
     try:
-        image_base64 = render_pdf_to_base64png(
-            pdf_path,
-            page_num=page_num + 1,  # render_pdf_to_base64png is 1-indexed
-            target_longest_image_dim=2048
-        )
+        image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num + 1, target_longest_image_dim=2048)  # render_pdf_to_base64png is 1-indexed
     except Exception as e:
         print(f"Error rendering PDF page: {str(e)}")
         return None

     image_part = types.Part(inline_data=types.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))

     contents = [
         types.Content(
             role="user",
             parts=[
                 image_part,
-                types.Part.from_text(text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""),
+                types.Part.from_text(
+                    text="""Please tell me which text in this image is part of any headers/footers and would therefore be skipped it someone were reading it outloud to another person. Include page numbers and document-level headers and footers, but not inner subsections."""
+                ),
             ],
         ),
     ]
@@ -149,47 +145,38 @@ def detect_headers_footers(pdf_path: str, page_num: int, api_key: str) -> Optional[List[str]]:
         max_output_tokens=8192,
         response_mime_type="application/json",
         response_schema=genai.types.Schema(
-            type = genai.types.Type.OBJECT,
-            properties = {
+            type=genai.types.Type.OBJECT,
+            properties={
                 "headers": genai.types.Schema(
-                    type = genai.types.Type.ARRAY,
-                    items = genai.types.Schema(
-                        type = genai.types.Type.STRING,
+                    type=genai.types.Type.ARRAY,
+                    items=genai.types.Schema(
+                        type=genai.types.Type.STRING,
                     ),
                 ),
                 "footers": genai.types.Schema(
-                    type = genai.types.Type.ARRAY,
-                    items = genai.types.Schema(
-                        type = genai.types.Type.STRING,
+                    type=genai.types.Type.ARRAY,
+                    items=genai.types.Schema(
+                        type=genai.types.Type.STRING,
                     ),
                 ),
             },
         ),
     )

-    response = client.models.generate_content(model=model,
-                                              contents=contents,
-                                              config=generate_content_config)
+    response = client.models.generate_content(model=model, contents=contents, config=generate_content_config)

     assert len(response.candidates) > 0, "No candidates found"
-    assert (
-        response.candidates[0].finish_reason == types.FinishReason.STOP
-    ), "Finish reason was not STOP, likely a processing error or repetition failure"
+    assert response.candidates[0].finish_reason == types.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"

     data = json.loads(response.candidates[0].content.parts[0].text)

     return data.get("headers", []) + data.get("footers", [])

-def process_pdf(
-    s3_path: str,
-    temp_dir: str,
-    output_dir: str,
-    api_key: str,
-    tests: List[TextPresenceTest]
-) -> None:
+def process_pdf(s3_path: str, temp_dir: str, output_dir: str, api_key: str, tests: List[TextPresenceTest]) -> None:
     """
     Process a single PDF from S3.

     Args:
         s3_path: S3 path to the PDF
         temp_dir: Directory for temporary files
@@ -200,7 +187,7 @@ def process_pdf(
     # Extract filename from S3 path
     pdf_filename = os.path.basename(s3_path)
     local_pdf_path = os.path.join(temp_dir, pdf_filename)

     # Download PDF from S3
     if not download_pdf_from_s3(s3_path, local_pdf_path):
         return
@@ -210,16 +197,16 @@ def process_pdf(
     if pdf_filter.filter_out_pdf(local_pdf_path):
         print("Filtering out", pdf_filename)
         return

     try:
         # Read the PDF to get the number of pages
         reader = pypdf.PdfReader(local_pdf_path)
         num_pages = len(reader.pages)

         if num_pages == 0:
             print(f"PDF {pdf_filename} has no pages")
             return

         all_pages = list(range(len(reader.pages)))
         random.shuffle(all_pages)

@@ -229,20 +216,19 @@ def process_pdf(

             # Only stick with headers and footers that have some actual data in them
             header_footer_text = [x for x in header_footer_text if len(x.strip()) > 3]

             if not header_footer_text:
                 print(f"No headers/footers detected in {pdf_filename} page {page_num}")
                 continue

             # Extract the page and save to output dir
             pdf_basename = os.path.splitext(pdf_filename)[0]
             output_pdf_path = os.path.join(output_dir, "pdfs", f"{pdf_basename}_pg{page_num+1}.pdf")

             extract_page_from_pdf(local_pdf_path, output_pdf_path, page_num)

             # TODO Now, process it again to make sure extracted headers/footers don't appear in the main body of the text
-

             # Create tests for each header/footer text
             for i, text in enumerate(header_footer_text):
                 test_id = f"{pdf_basename}_pg{page_num+1}_header_{i:02d}"
@@ -255,10 +241,10 @@ def process_pdf(
                     max_diffs=0,
                 )
                 tests.append(test)

             print(f"Processed {pdf_filename} page {page_num+1}, found {len(header_footer_text)} headers/footers")
             return

     except Exception as e:
         print(f"Error processing {pdf_filename}: {str(e)}")
     finally:
@@ -274,39 +260,37 @@ def main():
     parser.add_argument("--api_key", help="Gemini API key (if not provided, will use GEMINI_API_KEY environment variable)")
     parser.add_argument("--temp_dir", default="/tmp/mine_headers_footers", help="Directory for temporary files")
     args = parser.parse_args()

     # Get API key
     api_key = args.api_key or os.environ.get("GEMINI_API_KEY")
     if not api_key:
         print("Error: Gemini API key not provided. Use --api_key or set GEMINI_API_KEY environment variable.")
         return

     # Create directories
     os.makedirs(args.temp_dir, exist_ok=True)
     os.makedirs(os.path.join(args.output_dir, "pdfs"), exist_ok=True)

     # Read input list
     with open(args.input_list, "r") as f:
         s3_paths = [line.strip() for line in f if line.strip()]

     print(f"Found {len(s3_paths)} PDF paths in input list")

     # Process each PDF
     tests = []
     for s3_path in tqdm(s3_paths, desc="Processing PDFs"):
         process_pdf(s3_path, args.temp_dir, args.output_dir, api_key, tests)

         # Save tests after each PDF to avoid losing data in case of crashes
         if tests:
             save_tests(tests, os.path.join(args.output_dir, "header_footer_tests.jsonl"))
-

         if len(tests) > 100:
             break
-

     print(f"Saved {len(tests)} tests to {os.path.join(args.output_dir, 'header_footer_tests.jsonl')}")


 if __name__ == "__main__":
     main()
@@ -141,7 +141,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim
                 }
             ],
         "max_tokens": MAX_TOKENS,
-        "temperature": 0.8,
+        "temperature": 0.0,
     }

@@ -213,7 +213,7 @@ async def apost(url, json_data):
 async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
+    TEMPERATURE_BY_ATTEMPT = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
     exponential_backoffs = 0
     local_anchor_text_len = args.target_anchor_text_len
     local_image_rotation = 0
@@ -222,6 +222,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:

     while attempt < MAX_RETRIES:
         query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
+        query["temperature"] = TEMPERATURE_BY_ATTEMPT[
+            min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
+        ]  # Change temperature as number of attempts increases to overcome repetition issues at expense of quality

         logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
