mirror of https://github.com/allenai/olmocr.git
synced 2025-08-17 21:32:05 +00:00

Cleaning up some unused code

commit c6062677aa (parent d8c13d05f6)
@@ -1,862 +0,0 @@
import os
import hashlib
import boto3
import sqlite3
import orjson
import argparse
import base64
import tempfile
import datetime
import posixpath
import threading
import logging
import psutil
import boto3.session
import urllib3.exceptions

from dataclasses import dataclass
from pypdf import PdfReader
from io import BytesIO
from PIL import Image
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, as_completed

from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt, PageResponse
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, parse_s3_path
from olmocr.check import check_poppler_version

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# File handler for DEBUG level and above with line-by-line flushing
class FlushFileHandler(logging.FileHandler):
    def emit(self, record):
        super().emit(record)
        self.flush()  # Explicitly flush after every log entry

file_handler = FlushFileHandler('birrpipeline-debug.log', mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Add handlers to the logger
logger.addHandler(file_handler)

# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)


class DatabaseManager:
    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1 indexed!
        round: int
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str, skip_init: bool=False):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        # Enable WAL mode so you can read and write concurrently
        self.cursor = self.conn.cursor()
        self.cursor.execute("PRAGMA journal_mode=WAL;")

        if not skip_init:
            self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                round INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_inf_path ON page_results(inference_s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        self.conn.commit()

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def clear_index(self):
        self.cursor.execute("""
            DELETE FROM processed_files;
        """)
        self.cursor.execute("""
            DELETE FROM page_results;
        """)
        self.conn.commit()

    def add_index_entries(self, index_entries: List['BatchInferenceRecord']):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
            self.conn.commit()

    def get_index_entries(self, pdf_s3_path: str) -> List['BatchInferenceRecord']:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (pdf_s3_path,))

        rows = self.cursor.fetchall()

        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                round=row[3],
                start_index=row[4],
                length=row[5],
                finish_reason=row[6],
                error=row[7]
            )
            for row in rows
        ]

    def delete_index_entries_by_inference_s3_path(self, inference_s3_path: str):
        self.cursor.execute("DELETE FROM page_results WHERE inference_s3_path = ?", (inference_s3_path,))
        self.conn.commit()

    def get_last_indexed_round(self) -> int:
        self.cursor.execute("""
            SELECT MAX(round)
            FROM page_results
        """)

        result = self.cursor.fetchone()
        return -1 if result[0] is None else result[0]

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            logger.warning(f"PDF with s3_path '{s3_path}' already exists.")

    def update_pdf_statuses(self, status_updates: Dict[str, str]) -> None:
        """
        Update the status of multiple PDFs in the database.

        :param status_updates: A dictionary where each key is an s3_path (str) and
                               each value is the new status (str) for that PDF.
        """
        self.cursor.executemany("""
            UPDATE pdfs
            SET status = ?
            WHERE s3_path = ?
        """, [(new_status, s3_path) for s3_path, new_status in status_updates.items()])
        self.conn.commit()

    def get_pdf(self, s3_path: str) -> Optional['PDFRecord']:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE s3_path = ?
        """, (s3_path,))

        row = self.cursor.fetchone()

        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List['PDFRecord']:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE status == ?
            ORDER BY s3_path DESC, num_pages DESC
        """, (status, ))

        rows = self.cursor.fetchall()

        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()

class BatchWriter:
    def __init__(
        self,
        output_prefix: str,
        max_size_mb: int = 250,
        after_flush: Optional[Callable[[List[Any]], Any]] = None,
    ):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch_objects = []
        self.batch_size = 0
        self.after_flush = after_flush
        self.threads = []
        self.temp_file = None  # The temporary file object
        self.temp_file_path = None  # Path to the temporary file

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def write_line(self, obj: Optional[Any]):
        if obj is None:
            return

        line_bytes = orjson.dumps(obj)
        line_size = len(line_bytes) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        if self.batch_size == 0:
            # Open a new temporary file
            self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
            self.temp_file_path = self.temp_file.name

        self.temp_file.write(line_bytes + b"\n")
        self.batch_objects.append(obj)
        self.batch_size += line_size

    def _write_batch(self):
        if self.batch_size == 0:
            return

        # Close the temp file
        self.temp_file.flush()
        self.temp_file.close()

        # Start a new thread to upload the temp file
        thread = threading.Thread(
            target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
        )
        thread.start()
        self.threads.append(thread)

        # Reset batch_objects and batch_size
        self.batch_objects = []
        self.batch_size = 0
        self.temp_file = None
        self.temp_file_path = None

    def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
        # Compute hash based on file content
        hash_str = self._compute_hash(temp_file_path)
        output_path = self._get_output_path(hash_str)

        if self.is_s3:
            bucket, key = parse_s3_path(output_path)

            # Use the s3 client directly
            try:
                workspace_s3.upload_file(temp_file_path, bucket, key)
            except Exception as e:
                logger.error(f"Failed to upload {temp_file_path} to {output_path}: {e}", exc_info=True)
        else:
            # Move the temp file to the output path
            os.rename(temp_file_path, output_path)

        # After writing, call the after_flush callback if it is set
        if self.after_flush:
            self.after_flush(batch_objects)

        os.remove(temp_file_path)

    def _compute_hash(self, temp_file_path: str) -> str:
        """Compute a 20-character SHA1 hash of the file content."""
        sha1 = hashlib.sha1()
        with open(temp_file_path, "rb") as f:
            while True:
                data = f.read(1024*1024)
                if not data:
                    break
                sha1.update(data)
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip("/")
            if key and not key.endswith("/"):
                key += "/"
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def close(self):
        self._write_batch()
        # Wait for all threads to finish
        for thread in self.threads:
            thread.join()

def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            rotated_img = img.rotate(-image_rotation, expand=True)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

            # Encode the rotated image back to base64
            image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
    }


def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)

    start_index = 0
    index_entries = []
    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes
    for line in lines:
        line_length = len(line)  # Length in bytes

        try:
            # Parse the line directly as JSON
            data = orjson.loads(line)
            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            if data.get("completion_error", None) is not None:
                index_entries.append(DatabaseManager.BatchInferenceRecord(
                    inference_s3_path=inference_s3_path,
                    pdf_s3_path=pdf_s3_path,
                    page_num=page_num,
                    round=data["round"],
                    start_index=start_index,  # Byte offset in the original file
                    length=line_length,  # Length in bytes
                    finish_reason="completion_error",
                    error=data.get("completion_error", None)
                ))
            else:
                # Try to parse the actual model response JSON
                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

                try:
                    model_response_json = orjson.loads(data["outputs"][0]["text"])
                    page_response = PageResponse(**model_response_json)

                    last_error = data.get("completion_error", None)

                    if not page_response.is_rotation_valid:
                        last_error = "rotation_invalid"

                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=last_error,
                    ))
                except Exception as e:
                    error_type = type(e).__name__
                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=error_type,
                    ))

        except Exception as e:
            logger.exception(f"Error processing line in {inference_s3_path}: {e}")
            # Optionally, you might want to add an index entry indicating an error here

        start_index += line_length  # Increment by the number of bytes

    return index_entries


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    logger.debug(f"Startng to get_pdf_num_pages for {s3_path}")
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            logger.debug(f"Built reader for {s3_path}")
            return reader.get_num_pages()
    except Exception as ex:
        logger.warning(f"Warning, could not add {s3_path} due to {ex}")

    return None


def _get_page_data(page_index_entries: List[DatabaseManager.BatchInferenceRecord]) -> List[PageResponse]:
    usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
                                     start_index=page.start_index,
                                     end_index=page.start_index + page.length - 1) for page in page_index_entries]

    usable_page_final_results = []
    for page_data in usable_page_data:
        data = orjson.loads(page_data)
        model_response_json = orjson.loads(data["outputs"][0]["text"])
        page_response = PageResponse(**model_response_json)
        usable_page_final_results.append(page_response)

    return usable_page_final_results


def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int, target_longest_image_dim: int, target_anchor_text_len: int) -> list[dict]:
    db = DatabaseManager(s3_workspace, skip_init=True)

    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        db.close()
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Is there an existing page that has no error
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)

                if has_errored_previously:
                    # Retry the page at least one more time regularly
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})

                    # If the rotation was previously invalid, then apply a rotation
                    rotated_page_data = _get_page_data([page for page in existing_pages if page.page_num == target_page_num and page.error == "rotation_invalid"])
                    rotation_corrections = set(page_data.rotation_correction for page_data in rotated_page_data)
                    for correction in rotation_corrections:
                        logger.debug(f"Adding {correction}-degree rotation query for {pdf.s3_path}-{target_page_num}")
                        new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len, image_rotation=correction), "round": cur_round})

                    # TODO: Try to provide a smaller prompt hint if that was the error
                else:
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
    except Exception as ex:
        logger.warning(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")

    db.close()
    return new_queries


def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
    db = DatabaseManager(s3_workspace, skip_init=True)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    # Error out quickly if this document cannot be assembled
    for target_page_num in range(1, pdf.num_pages + 1):
        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

        if len(usable_pages) == 0:
            db.close()
            return None

    for target_page_num in range(1, pdf.num_pages + 1):
        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]
        usable_page_final_results = _get_page_data(usable_pages)

        # Sort the pages:
        # 1. Prefer pages with `is_rotation_valid` set to True.
        # 2. Within those, sort by the length of the `natural_text` in descending order.
        usable_page_final_results.sort(
            key=lambda page: (not page.is_rotation_valid, -len(page.natural_text or ""))
        )

        target_page_final_result = usable_page_final_results[0]

        if target_page_final_result.natural_text is not None:
            document_text += target_page_final_result.natural_text + "\n"

        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }
    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "olmocr",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    db.close()
    return dolma_doc


def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
    db = DatabaseManager(s3_workspace, skip_init=True)
    db.update_pdf_statuses({doc["metadata"]["Source-File"]: "completed" for doc in dolma_docs})
    db.close()


def get_current_round(s3_workspace: str) -> int:
    path = s3_workspace[5:]
    bucket, _, prefix = path.partition('/')

    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
    paginator = workspace_s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')

    round_numbers = []
    for page in page_iterator:
        for common_prefix in page.get('CommonPrefixes', []):
            round_prefix = common_prefix.get('Prefix')
            # Extract 'round_X' from the prefix
            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
            if round_dir.startswith('round_'):
                try:
                    round_num = int(round_dir[len('round_'):])
                    round_numbers.append(round_num)
                except ValueError:
                    pass
    if round_numbers:
        current_round = max(round_numbers) + 1
    else:
        current_round = 0
    return current_round

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    parser.add_argument('--workers', type=int, help='Number of workers to run in the processpool')
    parser.add_argument('--reindex', action='store_true', default=False, help='Reindex all of the page_results')
    parser.add_argument('--skip_build_queries', action='store_true', default=False, help='Skip generation of new pdf page queries for batch inferencing')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    db = DatabaseManager(args.workspace)
    logger.info(f"Loaded db at {db.db_path}")

    if args.reindex:
        db.clear_index()
        logger.info("Cleared existing index.")

    current_round = get_current_round(args.workspace)
    logger.info(f"Current round is {current_round}")

    check_poppler_version()

    # One shared executor to rule them all
    executor = ProcessPoolExecutor(max_workers=args.workers)

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        if args.add_pdfs.startswith("s3://"):
            logger.info(f"Querying all PDFs at {args.add_pdfs}")

            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
            logger.info(f"Found {len(all_pdfs):,} total pdf paths")
        elif os.path.exists(args.add_pdfs):
            with open(args.add_pdfs, "r") as f:
                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
        else:
            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        logger.info(f"Need to import {len(all_pdfs):,} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Adding PDFs"):
            s3_path = future_to_path[future]
            num_pages = future.result()
            logger.debug(f"Got {num_pages} pages back for {s3_path}")
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        logger.info("Completed adding new PDFs.")

    # Now build an index of all the pages that were processed within the workspace so far
    logger.info("Indexing all batch inference sent to this workspace")
    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = {
        s3_path: etag for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    }

    logger.info(f"Found {len(inference_output_paths):,} new batch inference results to index")
    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Indexing Inference Results"):
        s3_path, etag = future_to_path.pop(future)
        try:
            inference_records = future.result()

            db.delete_index_entries_by_inference_s3_path(s3_path)
            db.add_index_entries(inference_records)
            db.update_processed_file(s3_path, etag=etag)
        except urllib3.exceptions.SSLError:
            logger.warning(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")
        except Exception as e:
            logger.exception(f"Failed to index inference file {s3_path}: {e}")

    # Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
    if db.get_last_indexed_round() < current_round - 1:
        logger.warning(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    elif args.skip_build_queries:
        logger.info(f"Skipping generating new batch inference files")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    else:
        logger.info("Creating batch inference files for new PDFs")
        pdf_list = list(db.get_pdfs_by_status("pending"))
        pdf_iter = iter(pdf_list)
        pending_futures = {}
        potentially_done_pdfs = []
        lines_written = 0
        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
        total_pdfs = len(pdf_list)
        max_pending = 300

        with tqdm(total=total_pdfs, desc="Building Batch Queries") as pbar:
            # Submit initial batch of futures
            for _ in range(min(max_pending, total_pdfs)):
                pdf = next(pdf_iter)
                future = executor.submit(
                    build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim,args.target_anchor_text_len,
                )
                pending_futures[future] = pdf

            while pending_futures:
                # Wait for the next future to complete
                done, _ = concurrent.futures.wait(
                    pending_futures.keys(),
                    return_when=concurrent.futures.FIRST_COMPLETED,
                )

                for future in done:
                    pdf = pending_futures.pop(future)
                    inference_lines = future.result()

                    if len(inference_lines) == 0:
                        potentially_done_pdfs.append(pdf)

                    for line in inference_lines:
                        lines_written += 1

                        if line is not None:
                            new_inference_writer.write_line(line)

                    pbar.update(1)

                    # Submit a new future if there are more PDFs
                    try:
                        pdf = next(pdf_iter)
                        future = executor.submit(
                            build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim,args.target_anchor_text_len,
                        )
                        pending_futures[future] = pdf
                    except StopIteration:
                        pass  # No more PDFs to process

        new_inference_writer.close()

        if lines_written > 0:
            logger.info(f"Added {lines_written:,} new batch inference requests")

    # Now, finally, assemble any potentially done docs into dolma documents
    logger.info(f"Assembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Assembling Dolma Docs"):
        pdf = future_to_path.pop(future)
        dolma_doc = future.result()

        if dolma_doc is not None:
            new_output_writer.write_line(dolma_doc)

    new_output_writer.close()

    logger.info("Final statistics:")

    # Output the number of PDFs in each status "pending" and "completed"
    pending_pdfs = db.get_pdfs_by_status("pending")
    completed_pdfs = db.get_pdfs_by_status("completed")

    logger.info(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
    logger.info(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

    # For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
    total_rounds = db.get_last_indexed_round() + 1
    for round_num in range(total_rounds):
        db.cursor.execute("""
            SELECT COUNT(*), error, finish_reason
            FROM page_results
            WHERE round = ?
            GROUP BY error, finish_reason
        """, (round_num,))

        results = db.cursor.fetchall()

        total_pages = sum(count for count, _, _ in results)
        logger.info(f"Inference Round {round_num} - {total_pages:,} pages processed:")

        for count, error, finish_reason in results:
            error_str = error if error is not None else "None"
            logger.info(f" (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

    logger.info("Work finished, waiting for all workers to finish cleaning up")
    executor.shutdown(wait=True)
    db.close()
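For readers skimming this diff, here is a minimal usage sketch of the BatchWriter class removed above. The workspace prefix, record fields, and page range below are illustrative placeholders, not values from the repository: the writer buffers JSON-serializable records into a temporary file, rolls the batch over into a content-hashed output_<sha1>.jsonl part once the size cap is exceeded, and uploads or moves each finished part on a background thread.

# Hypothetical sketch only; assumes the BatchWriter defined in the removed file above.
writer = BatchWriter("s3://example-bucket/workspace/inference_inputs/round_0", max_size_mb=250)
for query in ({"custom_id": f"s3://example-bucket/doc.pdf-{page}", "round": 0} for page in range(1, 4)):
    writer.write_line(query)   # buffered; a full batch is flushed and uploaded on a background thread
writer.close()                 # flushes the final partial batch and joins all upload threads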
@@ -1,116 +0,0 @@
import concurrent.futures
import threading
import queue

class CappedFuture(concurrent.futures.Future):
    def __init__(self, semaphore):
        super().__init__()
        self._semaphore = semaphore
        self._result_retrieved = False
        self._underlying_future = None
        self._condition = threading.Condition()

    def set_underlying_future(self, underlying_future):
        with self._condition:
            self._underlying_future = underlying_future
            # Transfer the result when the underlying future completes
            underlying_future.add_done_callback(self._transfer_result)

    def _transfer_result(self, underlying_future):
        if underlying_future.cancelled():
            self.set_cancelled()
        elif underlying_future.exception() is not None:
            self.set_exception(underlying_future.exception())
        else:
            try:
                result = underlying_future.result()
                self.set_result(result)
            except Exception as e:
                self.set_exception(e)

    def result(self, timeout=None):
        res = super().result(timeout)
        self._release_semaphore()
        return res

    def exception(self, timeout=None):
        exc = super().exception(timeout)
        self._release_semaphore()
        return exc

    def _release_semaphore(self):
        if not self._result_retrieved:
            self._result_retrieved = True
            self._semaphore.release()

    def cancel(self):
        with self._condition:
            if self._underlying_future is not None:
                cancelled = self._underlying_future.cancel()
                if cancelled:
                    super().cancel()
                return cancelled
            else:
                # Task has not been submitted yet; cancel directly
                return super().cancel()

    def cancelled(self):
        return super().cancelled()

    def running(self):
        with self._condition:
            if self._underlying_future is not None:
                return self._underlying_future.running()
            else:
                return False

    def done(self):
        return super().done()

class CappedProcessPoolExecutor(concurrent.futures.Executor):
    def __init__(self, max_unprocessed=100, max_workers=None):
        self._max_unprocessed = max_unprocessed
        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
        self._task_queue = queue.Queue()
        self._shutdown = threading.Event()
        self._shutdown_lock = threading.Lock()
        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
        self._worker_thread = threading.Thread(target=self._worker)
        self._worker_thread.daemon = True
        self._worker_thread.start()

    def submit(self, fn, *args, **kwargs):
        if self._shutdown.is_set():
            raise RuntimeError('Cannot submit new tasks after shutdown')
        # Create a CappedFuture to return to the user
        user_future = CappedFuture(self._semaphore)
        # Put the task in the queue
        self._task_queue.put((user_future, fn, args, kwargs))
        return user_future

    def _worker(self):
        while True:
            if self._shutdown.is_set() and self._task_queue.empty():
                break
            try:
                user_future, fn, args, kwargs = self._task_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            self._semaphore.acquire()
            if user_future.cancelled():
                self._semaphore.release()
                continue
            # Submit the task to the underlying executor
            try:
                underlying_future = self._executor.submit(fn, *args, **kwargs)
                user_future.set_underlying_future(underlying_future)
            except Exception as e:
                user_future.set_exception(e)
                self._semaphore.release()
                continue

    def shutdown(self, wait=True):
        with self._shutdown_lock:
            self._shutdown.set()
        self._worker_thread.join()
        self._executor.shutdown(wait=wait)
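A minimal sketch of how the CappedProcessPoolExecutor removed above is meant to be driven; the square function and the sizes below are illustrative assumptions, not code from the repository. submit() returns immediately, while the internal bounded semaphore lets at most max_unprocessed tasks sit in the underlying process pool with unretrieved results; retrieving a result frees a slot so the feeder thread can admit the next queued task.

# Hypothetical sketch only; `square` stands in for any picklable top-level function.
def square(x):
    return x * x

executor = CappedProcessPoolExecutor(max_unprocessed=100, max_workers=8)
futures = [executor.submit(square, i) for i in range(1000)]
for fut in futures:
    _ = fut.result()           # consuming a result releases a semaphore slot, admitting another task
executor.shutdown(wait=True)   # drains the internal queue, then shuts down the process pool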
@@ -15,7 +15,7 @@ from olmocr.s3_utils import (
    upload_zstd_csv,
    parse_s3_path
)
from pypdf import PdfReader


logger = logging.getLogger(__name__)
@@ -1,99 +0,0 @@
import unittest
import time
import concurrent.futures
from concurrent.futures import TimeoutError

# Assuming the CappedProcessPoolExecutor code is in a module named 'capped_executor'
from olmocr.cappedpool import CappedProcessPoolExecutor

# Define functions at the top level to ensure they are picklable by multiprocessing

def square(x):
    return x * x

def raise_exception():
    raise ValueError("Test exception")

def sleep_and_return(x, sleep_time):
    time.sleep(sleep_time)
    return x

def task(counter, max_counter, counter_lock):
    with counter_lock:
        counter.value += 1
        print(f"Task incrementing counter to {counter.value}")
        if counter.value > max_counter.value:
            max_counter.value = counter.value
    time.sleep(0.5)
    with counter_lock:
        counter.value -= 1
    return True

class TestCappedProcessPoolExecutor(unittest.TestCase):

    def test_basic_functionality(self):
        """Test that tasks are executed and results are correct."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            futures = [executor.submit(square, i) for i in range(10)]
            results = [f.result() for f in futures]
            expected = [i * i for i in range(10)]
            self.assertEqual(results, expected)

    def test_exception_handling(self):
        """Test that exceptions in tasks are properly raised."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(raise_exception)
            with self.assertRaises(ValueError):
                future.result()

    def test_cancellation(self):
        """Test that tasks can be cancelled before execution."""
        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
            future = executor.submit(time.sleep, 5)
            # Try to cancel immediately
            cancelled = future.cancel()
            self.assertTrue(cancelled)
            self.assertTrue(future.cancelled())
            # Attempt to get result; should raise CancelledError
            with self.assertRaises(concurrent.futures.CancelledError):
                future.result()

    def test_shutdown(self):
        """Test that the executor shuts down properly and does not accept new tasks."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        future = executor.submit(time.sleep, 1)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(time.sleep, 1)

    def test_capping_behavior(self):
        """Test that the number of concurrent tasks does not exceed max_unprocessed."""
        max_unprocessed = 3
        with CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
            from multiprocessing import Manager

            manager = Manager()
            counter = manager.Value('i', 0)
            max_counter = manager.Value('i', 0)
            counter_lock = manager.Lock()

            futures = [executor.submit(task, counter, max_counter, counter_lock) for _ in range(10)]

            for index, f in enumerate(futures):
                print(f"Future {index} returned {f.result()}")

            time.sleep(1)

            print(max_counter.value)
            self.assertLessEqual(max_counter.value, max_unprocessed)

    def test_submit_after_shutdown(self):
        """Test that submitting tasks after shutdown raises an error."""
        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
        executor.shutdown(wait=True)
        with self.assertRaises(RuntimeError):
            executor.submit(square, 2)


if __name__ == '__main__':
    unittest.main()
@@ -17,7 +17,7 @@ from io import BytesIO
 from PIL import Image
 from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
 from pathlib import Path
-from olmocr.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
+from olmocr.pipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
 from olmocr.prompts import PageResponse
 from httpx import AsyncClient
 import torch.nn.functional as F