# olmocr/pdelfin/beakerpipeline.py
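"""Batch inference pipeline that runs PDFs through a Qwen2-VL model served by sglang.

Work items are pulled from an S3-backed work queue, every page of each PDF is rendered
and sent to a locally launched sglang server, and the per-page responses are assembled
into Dolma-format JSONL documents that are uploaded back to the workspace's results/
prefix on S3. The script can also submit itself as a Beaker job (--beaker) or report
statistics about an existing workspace (--stats).
"""
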
import logging
import argparse
import boto3
import signal
import os
import sys
import time
import subprocess
import hashlib
import json
import base64
import atexit
import asyncio
import httpx
import datetime
import tempfile
import random
import re
import torch

from tqdm import tqdm
from io import BytesIO
from PIL import Image
from pypdf import PdfReader
from functools import partial, cache
from dataclasses import dataclass
from typing import Optional, Tuple, List, Dict, Set
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

from pdelfin.s3_queue import S3WorkQueue, WorkItem
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.filter.filter import PdfFilter, Language
from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.check import check_poppler_version
from pdelfin.metrics import MetricsKeeper, WorkerTracker
from pdelfin.version import VERSION

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False

file_handler = logging.FileHandler('beakerpipeline-debug.log', mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global s3 clients for the whole script; we have two separate ones in case your workspace and your pdfs are in different accounts
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Global variables for token statistics
metrics = MetricsKeeper(window=60*5)
tracker = WorkerTracker()

# Process pool for offloading cpu-bound work, like calculating anchor texts
process_pool = ProcessPoolExecutor()

# Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))

SGLANG_SERVER_PORT = 30024


@dataclass(frozen=True)
class PageResult:
    s3_path: str
    page_num: int
    response: PageResponse

    input_tokens: int
    output_tokens: int
    is_fallback: bool


async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
    """Build the chat-completion request for a single PDF page, pairing the rendered page image with its anchor text."""
    MAX_TOKENS = 3000

    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

    # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
    image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    # get_anchor_text is not thread safe, and it's also CPU bound, so it needs to run in a process pool
    loop = asyncio.get_running_loop()
    anchor_text = loop.run_in_executor(process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page)

    image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            rotated_img = img.rotate(-image_rotation, expand=True)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

            # Encode the rotated image back to base64
            image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return {
        "model": "Qwen/Qwen2-VL-7B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.8
    }


async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
    """Run one page through the sglang server, retrying on bad responses and falling back to pdftotext anchor text if every attempt fails."""
    COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
    MAX_RETRIES = args.max_page_retries

    exponential_backoffs = 0
    local_anchor_text_len = args.target_anchor_text_len
    local_image_rotation = 0
    attempt = 0
    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")

    while attempt < MAX_RETRIES:
        query = await build_page_query(
            pdf_local_path,
            page_num,
            args.target_longest_image_dim,
            local_anchor_text_len,
            image_rotation=local_image_rotation
        )

        try:
            response = await session.post(COMPLETION_URL, json=query)

            if response.status_code == 400:
                raise ValueError(f"Got BadRequestError from server: {response.text}, skipping this response")
            elif response.status_code == 500:
                raise ValueError(f"Got InternalServerError from server: {response.text}, skipping this response")
            else:
                response.raise_for_status()

            base_response_data = response.json()

            if base_response_data["usage"]["total_tokens"] > args.model_max_context:
                local_anchor_text_len = max(1, local_anchor_text_len // 2)
                logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_s3_path}-{page_num}")
                raise ValueError("Response exceeded model_max_context, cannot use this response")

            metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                                sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0))

            model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
            page_response = PageResponse(**model_response_json)

            if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
                logger.info(f"Got invalid_page rotation for {pdf_s3_path}-{page_num} attempt {attempt}, retrying with {page_response.rotation_correction} rotation")
                local_image_rotation = page_response.rotation_correction
                raise ValueError(f"invalid_page rotation for {pdf_s3_path}-{page_num}")

            await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")
            return PageResult(
                pdf_s3_path,
                page_num,
                page_response,
                input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                output_tokens=base_response_data["usage"].get("completion_tokens", 0),
                is_fallback=False,
            )
        except (httpx.TransportError, asyncio.TimeoutError) as e:
            logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")

            # Do exponential backoff, and don't count this as an actual page retry.
            # Page retries are supposed to be for fixing bad results from the model, but actual requests to sglang
            # are supposed to work. Most likely this means the server is just restarting.
            sleep_delay = 10 * (2 ** exponential_backoffs)
            exponential_backoffs += 1
            logger.info(f"Sleeping for {sleep_delay} seconds on {pdf_s3_path}-{page_num} to allow server restart")
            await asyncio.sleep(sleep_delay)
        except asyncio.CancelledError:
            logger.info(f"Process page {pdf_s3_path}-{page_num} cancelled")
            await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "cancelled")
            raise
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
            attempt += 1
        except ValueError as e:
            logger.warning(f"ValueError on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}")
            attempt += 1
        except Exception as e:
            logger.exception(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {type(e)} - {e}")
            attempt += 1

    logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")

    return PageResult(
        pdf_s3_path,
        page_num,
        PageResponse(natural_text=get_anchor_text(pdf_local_path, page_num, pdf_engine="pdftotext"),
                     primary_language=None, is_rotation_valid=True, rotation_correction=0, is_table=False, is_diagram=False),
        input_tokens=0,
        output_tokens=0,
        is_fallback=True
    )


async def process_pdf(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str):
    """Download a PDF from S3, process every page, and assemble the page results into a Dolma document."""
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
        # TODO Switch to aioboto3 or something
        data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_s3_path))
        tf.write(data)
        tf.flush()

        try:
            reader = PdfReader(tf.name)
            num_pages = reader.get_num_pages()
        except:
            logger.exception(f"Could not count number of pages for {pdf_s3_path}, aborting document")
            return None

        logger.info(f"Got {num_pages} pages to do for {pdf_s3_path} in worker {worker_id}")

        if args.apply_filter and get_pdf_filter().filter_out_pdf(tf.name):
            logger.info(f"Filtering out pdf {pdf_s3_path}")
            return None

        # List to hold the tasks for processing each page
        page_tasks = []
        page_results = []

        try:
            async with asyncio.TaskGroup() as tg:
                for page_num in range(1, num_pages + 1):
                    task = tg.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
                    page_tasks.append(task)

            # Collect the results from the entire task group, assuming no exceptions
            page_results = [task.result() for task in page_tasks]

            num_fallback_pages = sum(page_result.is_fallback for page_result in page_results)

            if num_fallback_pages / num_pages > args.max_page_error_rate:
                logger.error(f"Document {pdf_s3_path} has {num_fallback_pages} fallback pages out of {num_pages}, exceeding the max_page_error_rate of {args.max_page_error_rate}; discarding document.")
                return None
            elif num_fallback_pages > 0:
                logger.warning(f"Document {pdf_s3_path} processed with {num_fallback_pages} fallback pages out of {num_pages}, proceeding to build Dolma document.")

            return build_dolma_document(pdf_s3_path, page_results)
        except Exception as e:
            logger.exception(f"Exception in process_pdf for {pdf_s3_path}: {e}")
            # You can't build a dolma doc with even 1 failed page, so just get out of here
            # However, you don't want to propagate an exception higher up and cancel the entire work_group
            return None


def build_dolma_document(pdf_s3_path, page_results):
    # Build the document text and page spans
    document_text = ""
    pdf_page_spans = []
    current_char_pos = 0

    for index, page_result in enumerate(page_results):
        if page_result.response.natural_text is not None:
            content = page_result.response.natural_text + ("\n" if index < len(page_results) - 1 else "")
        else:
            content = ""

        start_pos = current_char_pos
        document_text += content
        current_char_pos = len(document_text)
        pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])

    if not document_text:
        logger.info(f"No document text for {pdf_s3_path}")
        return None  # Return None if the document text is empty

    # Build the Dolma document
    metadata = {
        "Source-File": pdf_s3_path,
        "pdelfin-version": VERSION,
        "pdf-total-pages": len(page_results),
        "total-input-tokens": sum(page.input_tokens for page in page_results),
        "total-output-tokens": sum(page.output_tokens for page in page_results),
        "total-fallback-pages": sum(page.is_fallback for page in page_results),
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }
    return dolma_doc

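
# Example of the span bookkeeping above (illustrative values): a two-page document whose pages
# produced "Hello" and "World" yields text "Hello\nWorld" with
# pdf_page_numbers [[0, 6, 1], [6, 11, 2]] (start offset, end offset, 1-based page number).

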
async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
    """Pull work items off the queue, process their PDFs, and upload the resulting Dolma docs as JSONL to the workspace."""
    while True:
        # Wait until allowed to proceed
        await semaphore.acquire()

        work_item = await work_queue.get_work()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting due to empty queue")
            semaphore.release()
            break

        logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
        await tracker.clear_work(worker_id)

        try:
            async with httpx.AsyncClient(timeout=600, limits=httpx.Limits(max_keepalive_connections=0, max_connections=1000)) as session:
                async with asyncio.TaskGroup() as tg:
                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
                    logger.info(f"Created all tasks for {work_item.hash}")

                logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

            logger.info(f"Closed ClientSession for {work_item.hash}")

            dolma_docs = []
            for task in dolma_tasks:
                try:
                    result = task.result()
                except:
                    # some dolma doc creations may have failed, skip them
                    continue

                if result is not None:
                    dolma_docs.append(result)

            logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")

            # Write the Dolma documents to a local temporary file in JSONL format
            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
                for doc in dolma_docs:
                    tf.write(json.dumps(doc))
                    tf.write('\n')
                tf.flush()

                # Define the output S3 path using the work_hash
                output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl')

                bucket, key = parse_s3_path(output_s3_path)
                workspace_s3.upload_file(tf.name, bucket, key)

            # Update finished token counts from successful documents
            metrics.add_metrics(finished_input_tokens=sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs),
                                finished_output_tokens=sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs))

            # Update last batch time
            last_batch_time = time.perf_counter()
        except Exception as e:
            logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
        finally:
            await work_queue.mark_done(work_item)
            semaphore.release()


async def sglang_server_task(args, semaphore):
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)

    # Check the rope config and make sure it's got the proper key
    with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
        config_data = json.load(cfin)

    if "rope_type" in config_data["rope_scaling"]:
        del config_data["rope_scaling"]["rope_type"]
        config_data["rope_scaling"]["type"] = "mrope"

        with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
            json.dump(config_data, cfout)

    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)  # Convert to GB
    mem_fraction_arg = ["--mem-fraction-static", "0.80"] if gpu_memory < 60 else []

    cmd = [
        "python3",
        "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        # "--context-length", str(args.model_max_context),  # Commented out due to crashes
        "--port", str(SGLANG_SERVER_PORT),
        "--log-level-http", "warning",
    ]
    cmd.extend(mem_fraction_arg)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Ensure the subprocess is terminated on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    # Shared variables between tasks
    last_running_req, last_queue_req = 0, 0
    server_printed_ready_message = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
        sglang_logger.info(line)

        if "Detected errors during sampling" in line:
            logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
            sys.exit(1)

        if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
            server_printed_ready_message = True
            last_semaphore_release = time.time()

        match = re.search(r'#running-req: (\d+)', line)
        if match:
            last_running_req = int(match.group(1))

        match = re.search(r'#queue-req: (\d+)', line)
        if match:
            last_queue_req = int(match.group(1))
            logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            line = line.decode('utf-8').rstrip()
            await process_line(line)

    async def timeout_task():
        nonlocal last_running_req, last_queue_req, last_semaphore_release
        try:
            while True:
                await asyncio.sleep(1)
                if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")
        except asyncio.CancelledError:
            pass  # Clean up if the task is cancelled

    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_task = asyncio.create_task(timeout_task())

    await proc.wait()

    timeout_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)


async def sglang_server_host(args, semaphore):
    while True:
        await sglang_server_task(args, semaphore)
        logger.warning("SGLang server task ended")


async def sglang_server_ready():
    max_attempts = 300
    delay_sec = 1
    url = f'http://localhost:{SGLANG_SERVER_PORT}/v1/models'

    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)

                if response.status_code == 200:
                    logger.info("sglang server is ready.")
                    return
                else:
                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception as e:
            logger.warning(f"Attempt {attempt}: {e}")

        await asyncio.sleep(delay_sec)

    raise Exception("sglang server did not become ready after waiting.")


async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)


def submit_beaker_job(args):
    from beaker import (
        Beaker,
        Constraints,
        DataMount,
        DataSource,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/pdelfin-inference-{VERSION}"

    task_name = f"pdelfin-{os.path.basename(args.workspace.rstrip('/'))}"

    # Take out --beaker flag so the workers will just run things
    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    # Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
    args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i-1] == "--pdfs"))]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]")

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-AWS_CREDENTIALS_FILE", open(os.path.join(os.path.expanduser('~'), '.aws', 'credentials')).read(), args.beaker_workspace)

    try:
        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
    except SecretNotFound:
        print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
        lines = []
        prev_empty = False
        for line in iter(input, None):
            if not line and prev_empty:
                break
            prev_empty = not line
            lines.append(line)
        gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
        if gcs_sa_key:
            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=args.beaker_gpus,
                context=TaskContext(
                    priority=Priority(args.beaker_priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "-m", "pdelfin.beakerpipeline"] + args_list,
                env_vars=[
                    EnvVar(name="BEAKER_JOB_NAME", value=task_name),
                    EnvVar(name="OWNER", value=owner),
                    EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
                    EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
                    EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
                    EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"),
                ],
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


def print_stats(args):
    # Get total work items and completed items
    index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "results", "*.jsonl")

    done_work_items = expand_s3_glob(workspace_s3, output_glob)

    work_queue = {
        parts[0]: parts[1:]
        for line in download_zstd_csv(workspace_s3, index_file_s3_path)
        if (parts := line.strip().split(",")) and line.strip()
    }

    total_items = len(work_queue)
    completed_items = len(done_work_items)

    def process_output_file(s3_path):
        try:
            data = get_s3_bytes(workspace_s3, s3_path)
            doc_count = 0
            total_input_tokens = 0
            total_output_tokens = 0
            total_pages = 0
            total_fallback_pages = 0
            processed_paths = set()

            for line in data.decode('utf-8').splitlines():
                if line.strip():
                    doc = json.loads(line)
                    doc_count += 1
                    total_input_tokens += doc["metadata"].get("total-input-tokens", 0)
                    total_output_tokens += doc["metadata"].get("total-output-tokens", 0)
                    total_pages += doc["metadata"].get("pdf-total-pages", 0)
                    total_fallback_pages += doc["metadata"].get("total-fallback-pages", 0)
                    processed_paths.add(doc["metadata"]["Source-File"])

            return doc_count, total_input_tokens, total_output_tokens, total_pages, total_fallback_pages, processed_paths
        except Exception as e:
            logger.warning(f"Error processing {s3_path}: {e}")
            return 0, 0, 0, 0, 0, set()

    print("\nProcessing output files...")
    docs_total = 0
    input_tokens_total = 0
    output_tokens_total = 0
    pages_total = 0
    fallback_pages_total = 0
    all_processed_paths = set()
    original_paths = set()

    # First collect all original PDF paths
    for done_work_item in done_work_items:
        if match := re.search(r"output_(\w+).jsonl", done_work_item):
            done_work_hash = match.group(1)
            original_paths.update(work_queue[done_work_hash])

    with ThreadPoolExecutor() as executor:
        futures = {executor.submit(process_output_file, item): item for item in done_work_items}

        for future in tqdm(as_completed(futures), total=len(futures)):
            doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths = future.result()
            docs_total += doc_count
            input_tokens_total += input_tokens
            output_tokens_total += output_tokens
            pages_total += pages
            fallback_pages_total += fallback_pages
            all_processed_paths.update(processed_paths)

    skipped_paths = original_paths - all_processed_paths

    print(f"\nWork Items Status:")
    print(f"Total work items: {total_items:,}")
    print(f"Completed items: {completed_items:,}")
    print(f"Remaining items: {total_items - completed_items:,}")

    print(f"\nResults:")
    print(f"Total documents processed: {docs_total:,}")
    print(f"Total documents skipped: {len(skipped_paths):,}")
    print(f"Total pages on fallback: {fallback_pages_total:,}")
    print(f"Total pages processed: {pages_total:,}")

    print(f"\nTotal output tokens: {output_tokens_total:,}")
    print(f"Projected output tokens: {round((output_tokens_total / max(1, completed_items)) * total_items):,}")

    print(f"\nAverage pages per doc: {pages_total/max(1,docs_total):,.1f}")
    print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
    print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")


async def main():
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--pages_per_group', type=int, default=500, help='Aiming for this many pdf pages per work item group')
    parser.add_argument('--max_page_retries', type=int, default=8, help='Max number of times we will retry rendering a page')
    parser.add_argument('--max_page_error_rate', type=float, default=0.004, help='Rate of allowable failed pages in a document, 1/250 by default')
    parser.add_argument('--workers', type=int, default=8, help='Number of workers to run at a time')
    parser.add_argument('--apply_filter', action='store_true', help='Apply basic filtering to keep English pdfs which are not forms and not likely seo spam')
    parser.add_argument('--stats', action='store_true', help='Instead of running any job, reports some statistics about the current workspace')

    # Model parameters
    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default=8192, help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)

    # Beaker/job running stuff
    parser.add_argument('--beaker', action='store_true', help='Submit this job to beaker instead of running locally')
    parser.add_argument('--beaker_workspace', help='Beaker workspace to submit to', default='ai2/pdelfin')
    parser.add_argument('--beaker_cluster', help='Beaker clusters you want to run on', default=["ai2/jupiter-cirrascale-2", "ai2/pluto-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"])
    parser.add_argument('--beaker_gpus', type=int, default=1, help="Number of gpu replicas to run")
    parser.add_argument('--beaker_priority', type=str, default="normal", help="Beaker priority level for the job")
    args = parser.parse_args()

    global workspace_s3, pdf_s3

    # Set up the job to work in the beaker environment: load secrets, adjust logging, etc.
    if "BEAKER_JOB_NAME" in os.environ:
        sglang_logger.addHandler(console_handler)

        cred_path = os.path.join(os.path.expanduser('~'), '.aws', 'credentials')
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

        cred_path = os.path.join(os.path.expanduser('~'), '.gcs', 'credentials')
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path

        workspace_s3 = boto3.client('s3')
        pdf_s3 = boto3.client('s3')

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    check_poppler_version()

    # Create work queue
    work_queue = S3WorkQueue(workspace_s3, args.workspace)

    if args.pdfs:
        logger.info("Got --pdfs argument, going to add to the work queue")

        # Expand s3 paths
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
        elif os.path.exists(args.pdfs):
            logger.info(f"Loading file at {args.pdfs}")
            with open(args.pdfs, "r") as f:
                s3_work_paths = list(filter(None, (line.strip() for line in f)))
        else:
            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

        s3_work_paths = set(s3_work_paths)
        logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")

        # Estimate average pages per pdf
        sample_size = min(100, len(s3_work_paths))
        sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
        page_counts = []

        for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
            try:
                # Download the PDF to a temp file
                with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
                    s3_bucket, s3_key = parse_s3_path(pdf)
                    pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
                    tmp_file.flush()
                    reader = PdfReader(tmp_file.name)
                    page_counts.append(len(reader.pages))
            except Exception as e:
                logger.warning(f"Failed to read {pdf}: {e}")

        if page_counts:
            avg_pages_per_pdf = sum(page_counts) / len(page_counts)
        else:
            logger.warning("Could not read any PDFs to estimate average page count.")
            avg_pages_per_pdf = 10  # Default to 10 pages per PDF if sampling fails

        items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
        logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

        # Now call populate_queue
        await work_queue.populate_queue(s3_work_paths, items_per_group)

    if args.stats:
        print_stats(args)
        return

    if args.beaker:
        submit_beaker_job(args)
        return

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Initialize the work queue
    await work_queue.initialize_queue()

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))

    await sglang_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    # Wait for server to stop
    process_pool.shutdown(wait=False)
    sglang_server.cancel()
    metrics_task.cancel()
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())


# TODO
# - Sglang commit a fix for the context length issue
# - aiohttp repro and bug report
# - Get a solid benchmark on the stream vs non stream approach
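
# Example invocation (bucket names and paths below are illustrative only):
#   python -m pdelfin.beakerpipeline s3://my-bucket/pdelfin/workspace/ \
#       --pdfs "s3://my-bucket/pdfs/*.pdf" \
#       --workers 8
#
# Add --beaker to submit the same command as a Beaker job instead of running locally,
# or --stats to print statistics about an existing workspace.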