# olmocr/olmocr/pipeline.py

import argparse
import asyncio
import atexit
import base64
import datetime
import hashlib
import json
import logging
import multiprocessing
import os
import random
import re
import shutil
import sys
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from urllib.parse import urlparse

import boto3
import httpx
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader
from tqdm import tqdm

from olmocr.check import (
    check_poppler_version,
    check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_no_anchoring_yaml_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import (
    download_directory,
    download_zstd_csv,
    expand_s3_glob,
    get_s3_bytes,
    get_s3_bytes_with_backoff,
    parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

server_logger = logging.getLogger("vllm")
server_logger.propagate = False

file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
server_logger.addHandler(file_handler)
server_logger.addHandler(console_handler)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global S3 clients for the whole script; we keep two separate ones in case your workspace and your PDFs are in different accounts
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")

# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()

# Process pool for offloading CPU-bound work, like calculating anchor texts; capped at 32 workers, otherwise it can spawn way too many workers on a big machine
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))

# Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))

# Specify a default port, but it can be overridden by args
BASE_SERVER_PORT = 30024

@dataclass(frozen=True)
class PageResult:
    s3_path: str
    page_num: int
    response: PageResponse
    input_tokens: int
    output_tokens: int
    is_fallback: bool


async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
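    """Build a chat-completions request payload for a single PDF page.

    Renders the requested page to a base64 PNG (optionally rotated by 0/90/180/270
    degrees) and wraps it, together with the no-anchoring YAML prompt, into an
    OpenAI-style /v1/chat/completions request dict for the local vLLM server.
    """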
    MAX_TOKENS = 4500
    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

    # Render the page image in a background thread so we don't block the event loop
    image_base64 = await asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            if image_rotation == 90:
                transpose = Image.Transpose.ROTATE_90
            elif image_rotation == 180:
                transpose = Image.Transpose.ROTATE_180
            else:
                transpose = Image.Transpose.ROTATE_270
            rotated_img = img.transpose(transpose)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

            # Encode the rotated image back to base64
            image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return {
        "model": "olmocr",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_no_anchoring_yaml_prompt()},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.0,
    }


# Manual, simple implementation of HTTP POST
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the session pool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, they deadlock in different strange ways
async def apost(url, json_data):
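    """Minimal async HTTP/1.1 POST helper used instead of httpx/aiohttp.

    Sends json_data as a JSON body to url over a raw asyncio connection and
    returns a (status_code, response_body) tuple, e.g.
    ``status, body = await apost(COMPLETION_URL, json_data=query)`` as used in
    process_page. Only fixed Content-Length responses are supported.
    """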
    parsed_url = urlparse(url)
    host = parsed_url.hostname
    port = parsed_url.port or 80
    path = parsed_url.path or "/"
    writer = None

    try:
        reader, writer = await asyncio.open_connection(host, port)

        json_payload = json.dumps(json_data)
        request = (
            f"POST {path} HTTP/1.1\r\n"
            f"Host: {host}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {len(json_payload)}\r\n"
            f"Connection: close\r\n\r\n"
            f"{json_payload}"
        )
        writer.write(request.encode())
        await writer.drain()

        # Read status line
        status_line = await reader.readline()
        if not status_line:
            raise ConnectionError("No response from server")
        status_parts = status_line.decode().strip().split(" ", 2)
        if len(status_parts) < 2:
            raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
        status_code = int(status_parts[1])

        # Read headers
        headers = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            key, _, value = line.decode().partition(":")
            headers[key.strip().lower()] = value.strip()

        # Read response body
        if "content-length" in headers:
            body_length = int(headers["content-length"])
            response_body = await reader.readexactly(body_length)
        else:
            raise ConnectionError("Anything other than fixed content length responses are not implemented yet")

        return status_code, response_body
    except Exception as e:
        # Pass through errors
        raise e
    finally:
        # But just make sure to close the socket on your way out
        if writer is not None:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass


async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
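    """Run one page through the inference server, retrying with rising temperature.

    Returns a PageResult; if all MAX_RETRIES attempts fail, a fallback PageResult
    built from pdftotext anchor text is returned with is_fallback=True.
    """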
    COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
    MAX_RETRIES = args.max_page_retries
    MODEL_MAX_CONTEXT = 16384
    TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]
    exponential_backoffs = 0
    cumulative_rotation = 0  # Track cumulative rotation instead of a local per-attempt value
    attempt = 0
    await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "started")

    while attempt < MAX_RETRIES:
        lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
        query = await build_page_query(
            pdf_local_path,
            page_num,
            args.target_longest_image_dim,
            image_rotation=cumulative_rotation,
        )

        # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
        query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]

        # Enable guided decoding regex if needed
        if args.guided_decoding:
            query["guided_regex"] = (
                r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
            )

        logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")

        try:
            status_code, response_body = await apost(COMPLETION_URL, json_data=query)

            if status_code == 400:
                raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
            elif status_code == 500:
                raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
            elif status_code != 200:
                raise ValueError(f"Error http status {status_code}")

            base_response_data = json.loads(response_body)

            if base_response_data["usage"]["total_tokens"] > MODEL_MAX_CONTEXT:
                raise ValueError(f"Response exceeded model_max_context of {MODEL_MAX_CONTEXT}, cannot use this response")

            if base_response_data["choices"][0]["finish_reason"] != "stop":
                raise ValueError("Response did not finish with reason code 'stop', cannot use this response")

            metrics.add_metrics(
                server_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                server_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
            )

            model_response_markdown = base_response_data["choices"][0]["message"]["content"]

            parser = FrontMatterParser(front_matter_class=PageResponse)
            front_matter, text = parser._extract_front_matter_and_text(model_response_markdown)
            page_response = parser._parse_front_matter(front_matter, text)

            if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
                logger.info(
                    f"Got invalid_page rotation for {pdf_orig_path}-{page_num} attempt {attempt}, retrying with {page_response.rotation_correction} rotation"
                )
                # Add the rotation correction to the cumulative rotation
                cumulative_rotation = (cumulative_rotation + page_response.rotation_correction) % 360
                logger.info(f"Cumulative rotation is now {cumulative_rotation} degrees")
                raise ValueError(f"invalid_page rotation for {pdf_orig_path}-{page_num}")

            metrics.add_metrics(**{"completed_pages": 1, f"finished_on_attempt_{attempt}": 1})
            await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "finished")

            return PageResult(
                pdf_orig_path,
                page_num,
                page_response,
                input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                output_tokens=base_response_data["usage"].get("completion_tokens", 0),
                is_fallback=False,
            )
        except (ConnectionError, OSError, asyncio.TimeoutError) as e:
            logger.warning(f"Client error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} {e}")

            # Now we want to do exponential backoff, and not count this as an actual page retry
            # Page retries are supposed to be for fixing bad results from the model, but actual requests to vllm
            # are supposed to work. Probably this means that the server is just restarting
            sleep_delay = 10 * (2**exponential_backoffs)
            exponential_backoffs += 1
            logger.info(f"Sleeping for {sleep_delay} seconds on {pdf_orig_path}-{page_num} to allow server restart")
            await asyncio.sleep(sleep_delay)
        except asyncio.CancelledError:
            logger.info(f"Process page {pdf_orig_path}-{page_num} cancelled")
            await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "cancelled")
            raise
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error on attempt {attempt} for {pdf_orig_path}-{page_num}: {e}")
            attempt += 1
        except ValueError as e:
            logger.warning(f"ValueError on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
            attempt += 1
        except Exception as e:
            logger.exception(f"Unexpected error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
            attempt += 1

    logger.error(f"Failed to process {pdf_orig_path}-{page_num} after {MAX_RETRIES} attempts.")
    metrics.add_metrics(failed_pages=1)
    await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "errored")

    return PageResult(
        pdf_orig_path,
        page_num,
        PageResponse(
            natural_text=get_anchor_text(pdf_local_path, page_num, pdf_engine="pdftotext"),
            primary_language=None,
            is_rotation_valid=True,
            rotation_correction=0,
            is_table=False,
            is_diagram=False,
        ),
        input_tokens=0,
        output_tokens=0,
        is_fallback=True,
    )


async def process_pdf(args, worker_id: int, pdf_orig_path: str):
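    """Download a single PDF (or image) and process every page concurrently.

    Returns a Dolma document dict, or None if the file is missing, filtered out,
    unreadable, or exceeds the allowed fallback-page error rate.
    """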
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
        try:
            data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_orig_path))
            tf.write(data)
            tf.flush()
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                logger.info(f"S3 File Not found, skipping it completely {pdf_orig_path}")
                return None
            else:
                raise

        if is_png(tf.name) or is_jpeg(tf.name):
            logger.info(f"Converting {pdf_orig_path} from image to PDF format...")
            tf.seek(0)
            tf.write(convert_image_to_pdf_bytes(tf.name))
            tf.flush()

    try:
        try:
            reader = PdfReader(tf.name)
            num_pages = reader.get_num_pages()
        except Exception:
            logger.exception(f"Could not count number of pages for {pdf_orig_path}, aborting document")
            return None

        logger.debug(f"Got {num_pages} pages to do for {pdf_orig_path} in worker {worker_id}")

        if args.apply_filter and get_pdf_filter().filter_out_pdf(tf.name):
            logger.info(f"Filtering out pdf {pdf_orig_path}")
            return None

        # List to hold the tasks for processing each page
        page_tasks = []
        page_results = []

        try:
            async with asyncio.TaskGroup() as tg:
                for page_num in range(1, num_pages + 1):
                    task = tg.create_task(process_page(args, worker_id, pdf_orig_path, tf.name, page_num))
                    page_tasks.append(task)

            # Collect the results from the entire task group, assuming no exceptions
            page_results = [task.result() for task in page_tasks]

            num_fallback_pages = sum(page_result.is_fallback for page_result in page_results)

            if num_fallback_pages / num_pages > args.max_page_error_rate:
                logger.error(
                    f"Document {pdf_orig_path} has {num_fallback_pages} fallback pages out of {num_pages} exceeding max_page_error_rate of {args.max_page_error_rate}, discarding document."
                )
                return None
            elif num_fallback_pages > 0:
                logger.warning(
                    f"Document {pdf_orig_path} processed with {num_fallback_pages} fallback pages out of {num_pages}, proceeding to build Dolma document."
                )

            return build_dolma_document(pdf_orig_path, page_results)
        except Exception as e:
            # Check for ExceptionGroup with BrokenProcessPool
            if isinstance(e, ExceptionGroup):
                broken_pool, other = e.split(BrokenProcessPool)
                if broken_pool is not None:  # Found at least one BrokenProcessPool
                    logger.critical("Encountered BrokenProcessPool, exiting process.")
                    sys.exit(1)

            logger.exception(f"Exception in process_pdf for {pdf_orig_path}: {e}")
            # You can't build a dolma doc with even 1 failed page, so just get out of here
            # However, you don't want to propagate an exception higher up and cancel the entire work_group
            return None
    finally:
        if os.path.exists(tf.name):
            os.unlink(tf.name)


def build_dolma_document(pdf_orig_path, page_results):
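    """Assemble per-page model responses into a single Dolma document dict.

    Concatenates the natural text of each page, recording [start, end, page_num]
    character spans in attributes["pdf_page_numbers"], and returns None if the
    document ends up with no text at all.
    """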
    # Build the document text and page spans
    document_text = ""
    pdf_page_spans = []
    current_char_pos = 0

    for index, page_result in enumerate(page_results):
        if page_result.response.natural_text is not None:
            content = page_result.response.natural_text + ("\n" if index < len(page_results) - 1 else "")
        else:
            content = ""

        start_pos = current_char_pos
        document_text += content
        current_char_pos = len(document_text)
        pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])

    if not document_text:
        logger.info(f"No document text for {pdf_orig_path}")
        return None  # Return None if the document text is empty

    # Build the Dolma document
    metadata = {
        "Source-File": pdf_orig_path,
        "olmocr-version": VERSION,
        "pdf-total-pages": len(page_results),
        "total-input-tokens": sum(page.input_tokens for page in page_results),
        "total-output-tokens": sum(page.output_tokens for page in page_results),
        "total-fallback-pages": sum(page.is_fallback for page in page_results),
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "olmocr",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans,
            "primary_language": [p.response.primary_language for p in page_results],
            "is_rotation_valid": [p.response.is_rotation_valid for p in page_results],
            "rotation_correction": [p.response.rotation_correction for p in page_results],
            "is_table": [p.response.is_table for p in page_results],
            "is_diagram": [p.response.is_diagram for p in page_results],
        },
    }
    return dolma_doc


async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
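    """Pull work items off the queue and turn each group of PDFs into JSONL output.

    Each worker waits on the shared semaphore (released by the vLLM log watcher
    when the server queue drains), processes all PDFs in the work item, writes the
    resulting Dolma docs (and optional markdown) to the workspace, and marks the
    item done.
    """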
    while True:
        # Wait until allowed to proceed
        await semaphore.acquire()

        work_item = await work_queue.get_work()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting due to empty queue")
            semaphore.release()
            break

        logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
        await tracker.clear_work(worker_id)

        try:
            async with asyncio.TaskGroup() as tg:
                dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths]
                logger.info(f"Created all tasks for {work_item.hash}")

            logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

            dolma_docs = []
            for task in dolma_tasks:
                try:
                    result = task.result()
                except Exception:
                    # Some dolma doc creations may have failed, skip those results
                    continue

                if result is not None:
                    dolma_docs.append(result)

            logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")

            # Write the Dolma documents to a local temporary file in JSONL format
            with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tf:
                for doc in dolma_docs:
                    tf.write(json.dumps(doc))
                    tf.write("\n")
                tf.flush()
                temp_path = tf.name

            try:
                # Define the output S3 path using the work_hash
                output_final_path = os.path.join(args.workspace, "results", f"output_{work_item.hash}.jsonl")

                if output_final_path.startswith("s3://"):
                    bucket, key = parse_s3_path(output_final_path)
                    workspace_s3.upload_file(temp_path, bucket, key)
                else:
                    # Ensure the results directory exists for local workspace
                    os.makedirs(os.path.dirname(output_final_path), exist_ok=True)
                    shutil.copyfile(temp_path, output_final_path)
            finally:
                # Clean up the temporary file
                if os.path.exists(temp_path):
                    os.unlink(temp_path)

            # If --markdown flag is set, also write the natural text to markdown files
            if args.markdown:
                logger.info(f"Writing {len(dolma_docs)} markdown files for {work_item.hash}")
                for doc in dolma_docs:
                    source_file = doc["metadata"]["Source-File"]
                    natural_text = doc["text"]

                    # Create the output markdown path that preserves the folder structure
                    if source_file.startswith("s3://"):
                        # Extract the path after the bucket name for S3 sources
                        parsed = urlparse(source_file)
                        relative_path = parsed.path.lstrip("/")
                    else:
                        # For local files, use the full path
                        relative_path = source_file

                    # Change the extension to .md
                    md_filename = os.path.splitext(os.path.basename(relative_path))[0] + ".md"
                    # Get the directory path without the filename
                    dir_path = os.path.dirname(relative_path)

                    # Create the output markdown path
                    markdown_dir = os.path.join(args.workspace, "markdown", dir_path)
                    markdown_path = os.path.join(markdown_dir, md_filename)

                    # Create the directory structure if it doesn't exist
                    if markdown_path.startswith("s3://"):
                        # For S3 paths, we'll create a temporary file and upload it
                        with tempfile.NamedTemporaryFile(mode="w+", delete=False) as md_tf:
                            md_tf.write(natural_text)
                            md_tf.flush()
                            md_temp_path = md_tf.name
                        try:
                            md_bucket, md_key = parse_s3_path(markdown_path)
                            workspace_s3.upload_file(md_temp_path, md_bucket, md_key)
                        finally:
                            # Make sure to clean up the temporary file even if upload fails
                            if os.path.exists(md_temp_path):
                                os.unlink(md_temp_path)
                    else:
                        # For local paths, create the directory structure and write the file
                        os.makedirs(markdown_dir, exist_ok=True)
                        with open(markdown_path, "w") as md_f:
                            md_f.write(natural_text)

            # Update finished token counts from successful documents
            metrics.add_metrics(
                finished_input_tokens=sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs),
                finished_output_tokens=sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs),
            )

            await work_queue.mark_done(work_item)
        except Exception as e:
            logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
        finally:
            semaphore.release()


async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
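    """Launch `vllm serve` as a subprocess and supervise it.

    Forwards known and unknown CLI arguments to vLLM, mirrors its stdout/stderr
    into the logs, and releases the worker semaphore once the server is up and its
    request queue is empty.
    """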
    cmd = [
        "vllm",
        "serve",
        model_name_or_path,
        "--port",
        str(BASE_SERVER_PORT),
        "--disable-log-requests",
        "--uvicorn-log-level",
        "warning",
        "--served-model-name",
        "olmocr",
        "--tensor-parallel-size",
        str(args.tensor_parallel_size),
        "--data-parallel-size",
        str(args.data_parallel_size),
    ]

    if args.gpu_memory_utilization is not None:
        cmd.extend(["--gpu-memory-utilization", str(args.gpu_memory_utilization)])

    if args.max_model_len is not None:
        cmd.extend(["--max-model-len", str(args.max_model_len)])

    if unknown_args:
        cmd.extend(unknown_args)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        try:
            proc.terminate()
        except Exception:
            logger.info("VLLM Process already terminated")

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        server_logger.info(line)

        if "Detected errors during sampling" in line:
            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
            sys.exit(1)

        if not server_printed_ready_message and ("The server is fired up and ready to roll!" in line or "Starting vLLM API server" in line):
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r"Running: (\d+)", line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            try:
                line = line.decode("utf-8").rstrip()
                await process_line(line)
            except Exception as ex:
                logger.warning(f"Got {ex} when reading log line from inference server, skipping")

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for VLLM server")
        proc.terminate()
        try:
            await asyncio.wait_for(proc.wait(), timeout=10.0)
        except asyncio.TimeoutError:
            logger.warning("VLLM server did not terminate within 10 seconds")
        raise

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
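    """Keep restarting the vLLM server task, giving up after MAX_RETRIES restarts."""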
    MAX_RETRIES = 5
    retry = 0

    while retry < MAX_RETRIES:
        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
        logger.warning("VLLM server task ended")
        retry += 1

        if retry >= MAX_RETRIES:
            logger.error(f"Ended up starting the vllm server more than {retry} times, cancelling pipeline")
            logger.error("")
            logger.error(
                "Please make sure vllm is installed according to the latest instructions here: https://docs.vllm.ai/en/stable/getting_started/installation/gpu.html"
            )
            sys.exit(1)


async def vllm_server_ready():
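    """Poll the local /v1/models endpoint until the vLLM server responds with HTTP 200."""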
    max_attempts = 300
    delay_sec = 1
    url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("vllm server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception:
            logger.warning(f"Attempt {attempt}: Please wait for vllm server to become ready...")

        await asyncio.sleep(delay_sec)

    raise Exception("vllm server did not become ready after waiting.")


async def download_model(model_name_or_path: str, max_retries: int = 5):
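    """Resolve the model location, downloading it if necessary.

    Supports s3://, gs:// and weka:// directories, absolute local paths, and
    Hugging Face repo ids, retrying with exponential backoff on failure.
    """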
    for retry in range(max_retries):
        try:
            if model_name_or_path.startswith("s3://") or model_name_or_path.startswith("gs://") or model_name_or_path.startswith("weka://"):
                logger.info(f"Downloading model directory from '{model_name_or_path}'")
                model_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "olmocr", "model")
                # Delete existing model cache directory if it exists
                if os.path.exists(model_cache_dir):
                    shutil.rmtree(model_cache_dir)
                download_directory([model_name_or_path], model_cache_dir)
                return model_cache_dir
            elif os.path.isabs(model_name_or_path) and os.path.isdir(model_name_or_path):
                logger.info(f"Using local model path at '{model_name_or_path}'")
                return model_name_or_path
            else:
                logger.info(f"Downloading model with hugging face '{model_name_or_path}'")
                snapshot_download(repo_id=model_name_or_path)
                return model_name_or_path
        except Exception:
            if retry == max_retries - 1:
                raise  # Raise on final attempt and fail the job
            sleep_time = random.randrange(2, 20) * 2**retry
            logger.exception(f"Could not download model, sleeping for {sleep_time} seconds to retry ({retry + 1}/{max_retries})")
            await asyncio.sleep(sleep_time)


async def metrics_reporter(work_queue):
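    """Periodically log queue size, token metrics, and the per-worker status table."""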
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)


def submit_beaker_job(args):
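    """Submit this same pipeline invocation as a Beaker experiment instead of running locally."""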
    from beaker import (  # type: ignore
        Beaker,
        Constraints,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/olmocr-inference-{VERSION}"

    task_name = f"olmocr-{os.path.basename(args.workspace.rstrip('/'))}"

    # Take out --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(
            f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
        )

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(
            f"{owner}-AWS_CREDENTIALS_FILE",
            open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
            args.beaker_workspace,
        )

    env_var_secrets = [
        EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
        EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
        EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
    ]

    try:
        b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
    except SecretNotFound:
        pass

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
        env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
            env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "-m", "olmocr.pipeline"] + args_list,
                env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


def print_stats(args, root_work_queue):
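    """Print completion and token statistics for an S3 workspace by scanning its results/*.jsonl files."""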
    LONG_CONTEXT_THRESHOLD = 32768

    assert args.workspace.startswith("s3://"), "Printing stats functionality only works with s3 workspaces for now."

    # Get total work items and completed items
    index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "results", "*.jsonl")

    done_work_items = expand_s3_glob(workspace_s3, output_glob)

    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
    work_queue = {}
    for line in work_queue_lines:
        if line.strip():
            parts = root_work_queue._decode_csv_row(line.strip())
            if parts:  # Ensure we have at least one part
                work_queue[parts[0]] = parts[1:]

    total_items = len(work_queue)
    completed_items = len(done_work_items)

    def process_output_file(s3_path):
        try:
            data = get_s3_bytes(workspace_s3, s3_path)
            doc_count = 0
            total_input_tokens = 0
            total_output_tokens = 0
            total_pages = 0
            total_fallback_pages = 0
            processed_paths = set()

            # Counters for long context docs within a single file
            long_context_docs = 0
            long_context_tokens = 0

            for line in data.decode("utf-8").splitlines():
                if line.strip():
                    doc = json.loads(line)
                    doc_count += 1
                    doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
                    doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
                    doc_pages = doc["metadata"].get("pdf-total-pages", 0)
                    doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)
                    total_input_tokens += doc_input_tokens
                    total_output_tokens += doc_output_tokens
                    total_pages += doc_pages
                    total_fallback_pages += doc_fallback_pages
                    processed_paths.add(doc["metadata"]["Source-File"])

                    # Check if this doc exceeds the long context threshold
                    if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
                        long_context_docs += 1
                        long_context_tokens += doc_output_tokens

            return (
                doc_count,
                total_input_tokens,
                total_output_tokens,
                total_pages,
                total_fallback_pages,
                processed_paths,
                long_context_docs,
                long_context_tokens,
            )
        except Exception as e:
            logger.warning(f"Error processing {s3_path}: {e}")
            return 0, 0, 0, 0, 0, set(), 0, 0

    print(f"\nCompleted work items {completed_items:,} out of {total_items:,}: {completed_items/total_items*100:.2f}%")
    print("\nProcessing output files...")

    docs_total = 0
    input_tokens_total = 0
    output_tokens_total = 0
    pages_total = 0
    fallback_pages_total = 0
    all_processed_paths = set()
    original_paths = set()

    # Counters for long context documents across all files
    long_context_docs_count = 0
    long_context_tokens_total = 0

    # First collect all original PDF paths
    for done_work_item in done_work_items:
        if match := re.search(r"output_(\w+).jsonl", done_work_item):
            done_work_hash = match.group(1)
            if done_work_hash in work_queue:
                original_paths.update(work_queue[done_work_hash])

    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_output_file, item): item for item in done_work_items}

        for future in tqdm(as_completed(futures), total=len(futures)):
            (doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths, long_context_docs, long_context_tokens) = future.result()
            docs_total += doc_count
            input_tokens_total += input_tokens
            output_tokens_total += output_tokens
            pages_total += pages
            fallback_pages_total += fallback_pages
            all_processed_paths.update(processed_paths)
            long_context_docs_count += long_context_docs
            long_context_tokens_total += long_context_tokens

    skipped_paths = original_paths - all_processed_paths

    print("\nWork Items Status:")
    print(f"Total work items: {total_items:,}")
    print(f"Completed items: {completed_items:,}")
    print(f"Remaining items: {total_items - completed_items:,}")

    print("\nResults:")
    print(f"Total documents processed: {docs_total:,}")
    print(f"Total documents skipped: {len(skipped_paths):,}")
    print(f"Total pages on fallback: {fallback_pages_total:,}")
    print(f"Total pages processed: {pages_total:,}")

    print(f"\nTotal output tokens: {output_tokens_total:,}")
    print(f"Projected output tokens: {round((output_tokens_total/max(1, completed_items))*total_items):,}")

    print(f"\nAverage pages per doc: {pages_total/max(1,docs_total):,.1f}")
    print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
    print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")

    # Print long context documents stats
    print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
    print(f"Total tokens in long context documents: {long_context_tokens_total:,}")


async def main():
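    """Entry point: parse arguments, populate the work queue, start vLLM, and run workers.

    Typical local invocation (illustrative only, adjust the workspace and paths to your own data):

        python -m olmocr.pipeline ./localworkspace --markdown --pdfs path/to/*.pdf
    """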
    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
    parser.add_argument(
        "workspace",
        help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
    )
    parser.add_argument(
        "--pdfs",
        nargs="*",
        help="Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths",
        default=None,
    )
    parser.add_argument(
        "--model",
        help="Path where the model is located, allenai/olmOCR-7B-0825-FP8 is the default, can be local, s3, or hugging face.",
        default="allenai/olmOCR-7B-0825-FP8",
    )

    # More detailed config options, usually you shouldn't have to change these
    parser.add_argument("--workspace_profile", help="S3 configuration profile for accessing the workspace", default=None)
    parser.add_argument("--pdf_profile", help="S3 configuration profile for accessing the raw pdf documents", default=None)
    parser.add_argument("--pages_per_group", type=int, default=500, help="Aiming for this many pdf pages per work item group")
    parser.add_argument("--max_page_retries", type=int, default=8, help="Max number of times we will retry rendering a page")
    parser.add_argument("--max_page_error_rate", type=float, default=0.004, help="Rate of allowable failed pages in a document, 1/250 by default")
    parser.add_argument("--workers", type=int, default=20, help="Number of workers to run at a time")
    parser.add_argument("--apply_filter", action="store_true", help="Apply basic filtering to English pdfs which are not forms, and not likely seo spam")
    parser.add_argument("--stats", action="store_true", help="Instead of running any job, reports some statistics about the current workspace")
    parser.add_argument("--markdown", action="store_true", help="Also write natural text to markdown files preserving the folder structure of the input pdfs")
    parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
    parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
    parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")

    vllm_group = parser.add_argument_group(
        "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
    )
    vllm_group.add_argument(
        "--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache (passed through to vllm serve)."
    )
    vllm_group.add_argument("--max_model_len", type=int, default=16384, help="Upper bound (tokens) vLLM will allocate KV-cache for, lower if VLLM won't start")
    vllm_group.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
    vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
    vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")

    # Beaker/job running stuff
    beaker_group = parser.add_argument_group("beaker/cluster execution")
    beaker_group.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
    beaker_group.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
    beaker_group.add_argument(
        "--beaker_cluster",
        help="Beaker clusters you want to run on",
        default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
    )
    beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
    beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")

    args, unknown_args = parser.parse_known_args()

    logger.info(
        "If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu-memory-utilization 0.80 --max_model_len 16384"
    )
2024-11-13 13:23:29 -08:00
    global workspace_s3, pdf_s3

    # Set the global BASE_SERVER_PORT from args
    global BASE_SERVER_PORT
    BASE_SERVER_PORT = args.port

    # Set up the job to work in a beaker environment, load secrets, adjust logging, etc.
    if "BEAKER_JOB_NAME" in os.environ:
        cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

        cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

        workspace_s3 = boto3.client("s3")
        pdf_s3 = boto3.client("s3")

        # Wait a little bit so that not all beaker jobs in a task start at the same time and download the model at the same time
        replica_count = int(os.environ.get("BEAKER_REPLICA_COUNT", "1"))
        interval = 10 if (replica_count - 1) * 10 <= 30 else 30 / max(1, replica_count - 1)
        sleep_time = int(os.environ.get("BEAKER_REPLICA_RANK", "0")) * interval
        logger.info(f"Beaker job sleeping for {sleep_time} seconds to stagger model downloads")
        await asyncio.sleep(sleep_time)

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    # We need poppler to load the initial pdfs, even if we are not processing them here
    check_poppler_version()

    # Create work queue
    if args.workspace.startswith("s3://"):
        work_queue = WorkQueue(S3Backend(workspace_s3, args.workspace))
    else:
        work_queue = WorkQueue(LocalBackend(args.workspace))

    if args.pdfs:
        logger.info("Got --pdfs argument, going to add to the work queue")
        pdf_work_paths = set()

        for pdf_path in args.pdfs:
            # Expand s3 paths
            if pdf_path.startswith("s3://"):
                logger.info(f"Expanding s3 glob at {pdf_path}")
                pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
            elif os.path.exists(pdf_path):
                if (
                    pdf_path.lower().endswith(".pdf")
                    or pdf_path.lower().endswith(".png")
                    or pdf_path.lower().endswith(".jpg")
                    or pdf_path.lower().endswith(".jpeg")
                ):
                    with open(pdf_path, "rb") as pf:
                        file_header = pf.read(4)
                    if file_header == b"%PDF":
                        logger.info(f"Loading file at {pdf_path} as PDF document")
                        pdf_work_paths.add(pdf_path)
                    elif is_png(pdf_path) or is_jpeg(pdf_path):
                        logger.info(f"Loading file at {pdf_path} as image document")
                        pdf_work_paths.add(pdf_path)
                    else:
                        logger.warning(f"File at {pdf_path} is not a valid PDF")
                elif pdf_path.lower().endswith(".txt"):
                    logger.info(f"Loading file at {pdf_path} as list of paths")
                    with open(pdf_path, "r") as f:
                        pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
                else:
                    raise ValueError(f"Unsupported file extension for {pdf_path}")
            else:
                raise ValueError("pdfs argument needs to be either a local path, an s3 path, or an s3 glob pattern...")

        logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")

        # Estimate average pages per pdf
        sample_size = min(100, len(pdf_work_paths))
        sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
        page_counts = []

        for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
            try:
                # Download the PDF to a temp file
                with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
                    tmp_file.write(get_s3_bytes(pdf_s3, pdf))
                    tmp_file.flush()
                    if is_png(tmp_file.name) or is_jpeg(tmp_file.name):
                        page_counts.append(1)
                    else:
                        reader = PdfReader(tmp_file.name)
                        page_counts.append(len(reader.pages))
            except Exception as e:
                logger.warning(f"Failed to read {pdf}: {e}")

        if page_counts:
            avg_pages_per_pdf = sum(page_counts) / len(page_counts)
        else:
            logger.warning("Could not read any PDFs to estimate average page count.")
            avg_pages_per_pdf = 10  # Default to 10 pages per PDF if sampling fails

        items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
        logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

        # Now call populate_queue
        await work_queue.populate_queue(pdf_work_paths, items_per_group)

    if args.stats:
        print_stats(args, work_queue)
        return

    if args.beaker:
        submit_beaker_job(args)
        return

    # If you get this far, then you are doing inference and need a GPU
    # check_sglang_version()
    check_torch_gpu_available()

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Download the model before you do anything else
    model_name_or_path = await download_model(args.model)

    # Initialize the work queue
    qsize = await work_queue.initialize_queue()

    if qsize == 0:
        logger.info("No work to do, exiting")
        return

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))

    await vllm_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    # Wait for server to stop
    process_pool.shutdown(wait=False)

    vllm_server.cancel()
    metrics_task.cancel()

    # Wait for cancelled tasks to complete
    await asyncio.gather(vllm_server, metrics_task, return_exceptions=True)

    # Output final metrics summary
    metrics_summary = metrics.get_metrics_summary()
    logger.info("=" * 80)
    logger.info("FINAL METRICS SUMMARY")
    logger.info("=" * 80)
    logger.info(f"Total elapsed time: {metrics_summary['elapsed_time_seconds']:.2f} seconds")

    # Output token counts and rates
    total_metrics = metrics_summary["total_metrics"]
    rates = metrics_summary["rates"]

    logger.info(f"Total Server Input tokens: {total_metrics.get('server_input_tokens', 0):,}")
    logger.info(f"Total Server Output tokens: {total_metrics.get('server_output_tokens', 0):,}")
    logger.info(f"Finished input tokens: {total_metrics.get('finished_input_tokens', 0):,}")
    logger.info(f"Finished output tokens: {total_metrics.get('finished_output_tokens', 0):,}")
    logger.info(f"Completed pages: {total_metrics.get('completed_pages', 0):,}")
    logger.info(f"Failed pages: {total_metrics.get('failed_pages', 0):,}")
    logger.info(
        f"Page Failure rate: {total_metrics.get('failed_pages', 0) / max(total_metrics.get('completed_pages', 0) + total_metrics.get('failed_pages', 0), 1) * 100:.2f}%"
    )

    # Output finished_on_attempt statistics
    logger.info("")
    logger.info("Pages finished by attempt number:")
    total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
    cumulative = 0
    for i in range(args.max_page_retries):
        if f"finished_on_attempt_{i}" in total_metrics:
            count = total_metrics[f"finished_on_attempt_{i}"]
            cumulative += count
            percentage = (count / total_finished * 100) if total_finished > 0 else 0
            cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
            logger.info(f"  Attempt {i}: {count:,} pages ({percentage:.1f}%) - Cumulative: {cumulative:,} ({cumulative_percentage:.1f}%)")

    # Output rates
    if "server_input_tokens_per_sec" in rates:
        logger.info(f"Server Input tokens/sec rate: {rates['server_input_tokens_per_sec']:.2f}")
    if "server_output_tokens_per_sec" in rates:
        logger.info(f"Server Output tokens/sec rate: {rates['server_output_tokens_per_sec']:.2f}")
    if "finished_input_tokens_per_sec" in rates:
        logger.info(f"Finished Input tokens/sec rate: {rates['finished_input_tokens_per_sec']:.2f}")
    if "finished_output_tokens_per_sec" in rates:
        logger.info(f"Finished Output tokens/sec rate: {rates['finished_output_tokens_per_sec']:.2f}")

    logger.info("=" * 80)
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())