olmocr/pdelfin/beakerpipeline.py
import logging
import argparse
import boto3
import signal
import os
import sys
import time
import subprocess
import hashlib
import json
import base64
import atexit
import asyncio
import aiohttp
import datetime
import tempfile
import re

from tqdm import tqdm
from io import BytesIO
from PIL import Image
from pypdf import PdfReader
from functools import partial
from dataclasses import dataclass
from typing import Optional
from concurrent.futures import ProcessPoolExecutor

from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.check import check_poppler_version
from pdelfin.metrics import MetricsKeeper, WorkerTracker
from pdelfin.version import VERSION

# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False

file_handler = logging.FileHandler('beakerpipeline-debug.log', mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global S3 clients for the whole script; we keep two separate ones in case your workspace and your pdfs are in different accounts
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Global variables for token statistics
metrics = MetricsKeeper(window=60*5)
tracker = WorkerTracker()

# Process pool for offloading CPU-bound work, like calculating anchor texts
process_pool = ProcessPoolExecutor()

@dataclass(frozen=True)
class PageResult:
    s3_path: str
    page_num: int
    response: PageResponse

    input_tokens: int
    output_tokens: int

async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
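    """Build a chat-completion request body for a single PDF page.

    The page render (to a base64 PNG) and the anchor-text extraction run concurrently:
    rendering happens in a thread, while get_anchor_text runs in the process pool because
    it is CPU-bound and not thread safe. An optional rotation is applied to the image.
    """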
    MAX_TOKENS = 3000

    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

    # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
    image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

    # get_anchor_text is not thread safe, and it is also CPU bound, so it has to run in the process pool
    loop = asyncio.get_running_loop()
    anchor_text = loop.run_in_executor(process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page)

    image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text)

    if image_rotation != 0:
        image_bytes = base64.b64decode(image_base64)
        with Image.open(BytesIO(image_bytes)) as img:
            rotated_img = img.rotate(-image_rotation, expand=True)

            # Save the rotated image to a bytes buffer
            buffered = BytesIO()
            rotated_img.save(buffered, format="PNG")

        # Encode the rotated image back to base64
        image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')

    return {
        "model": "Qwen/Qwen2-VL-7B-Instruct",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.8
    }

def compute_workgroup_sha1(work_group: list[str]) -> str:
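    """Return a stable SHA1 over a group of PDF paths; the list is sorted first so the
    hash does not depend on ordering."""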
    sha1 = hashlib.sha1()
    # Ensure consistent ordering by sorting the list
    for pdf in sorted(work_group):
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()

async def populate_pdf_work_queue(args):
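    """Add any new PDFs (from an s3 glob or a local list file) to the workspace's
    pdf_index_list.csv.zstd, grouping them into work items of args.group_size."""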
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

    if args.pdfs.startswith("s3://"):
        logger.info(f"Expanding s3 glob at {args.pdfs}")
        all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
    elif os.path.exists(args.pdfs):
        logger.info(f"Loading file at {args.pdfs}")
        with open(args.pdfs, "r") as f:
            all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
    else:
        raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

    all_pdfs = set(all_pdfs)
    logger.info(f"Found {len(all_pdfs):,} total pdf paths")

    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

    # Parse existing work items into groups
    existing_groups = {}
    for line in existing_lines:
        if line.strip():
            parts = line.strip().split(",")
            group_hash = parts[0]
            group_pdfs = parts[1:]
            existing_groups[group_hash] = group_pdfs

    existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

    # Remove existing PDFs from all_pdfs
    new_pdfs = all_pdfs - existing_pdf_set
    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

    # Group the new PDFs into chunks of group_size
    # TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
    new_groups = []
    current_group = []
    for pdf in sorted(new_pdfs):  # Sort for consistency
        current_group.append(pdf)
        if len(current_group) == args.group_size:
            group_hash = compute_workgroup_sha1(current_group)
            new_groups.append((group_hash, current_group))
            current_group = []
    if current_group:
        group_hash = compute_workgroup_sha1(current_group)
        new_groups.append((group_hash, current_group))

    logger.info(f"Created {len(new_groups):,} new work groups")

    # Combine existing groups with new groups
    combined_groups = existing_groups.copy()
    for group_hash, group_pdfs in new_groups:
        combined_groups[group_hash] = group_pdfs

    # Prepare lines to write back
    combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]

    # Upload the combined work items back to S3
    if new_groups:
        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

    logger.info("Completed adding new PDFs.")

async def load_pdf_work_queue(args) -> asyncio.Queue:
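    """Return an asyncio.Queue of (work_hash, pdf_paths) work items that do not yet
    have a corresponding output_*.jsonl in the workspace."""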
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
    output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"

    # Define the two blocking I/O operations
    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)

    # Run both tasks concurrently
    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)

    # Process the work queue lines
    work_queue = {
        parts[0]: parts[1:]
        for line in work_queue_lines
        if (parts := line.strip().split(",")) and line.strip()
    }

    # Extract done work hashes
    done_work_hashes = {
        os.path.basename(item)[len('output_'):-len('.jsonl')]
        for item in done_work_items
        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
    }

    # Determine remaining work
    remaining_work_hashes = set(work_queue) - done_work_hashes
    remaining_work_queue = {
        hash_: work_queue[hash_]
        for hash_ in remaining_work_hashes
    }

    # Populate the asyncio.Queue with remaining work
    queue = asyncio.Queue()
    for work, pdfs in remaining_work_queue.items():
        await queue.put((work, pdfs))

    return queue

async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
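    """Run one page through the local sglang server and parse the model output into a
    PageResult, retrying up to MAX_RETRIES times on bad model responses."""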
    COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
    MAX_RETRIES = 3

    attempt = 0
    await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "started")

    while attempt < MAX_RETRIES:
        query = await build_page_query(
            pdf_local_path,
            page_num,
            args.target_longest_image_dim,
            args.target_anchor_text_len
        )

        try:
            async with session.post(COMPLETION_URL, json=query) as response:
                response.raise_for_status()

                base_response_data = await response.json()

                metrics.add_metrics(sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                                    sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0))

                model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
                page_response = PageResponse(**model_response_json)

                await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "finished")

                return PageResult(
                    pdf_s3_path,
                    page_num,
                    page_response,
                    input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                    output_tokens=base_response_data["usage"].get("completion_tokens", 0)
                )
        except aiohttp.ClientError as e:
            logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")

            # Back off without counting this as an actual page retry.
            # Page retries are meant to handle bad results from the model; requests to sglang itself
            # are expected to succeed, so a client error usually means the server is restarting.
            logger.info(f"Sleeping for 5 seconds on {pdf_s3_path}-{page_num} to allow server restart")
            await asyncio.sleep(5)
        except json.JSONDecodeError as e:
            logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
            attempt += 1
        except Exception as e:
            logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
            attempt += 1

    if attempt >= MAX_RETRIES:
        logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
        await tracker.track_work(worker_id, f"{pdf_s3_path}-{page_num}", "errored")
        raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")

async def process_pdf(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str):
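    """Download a PDF from S3, process all of its pages concurrently, and assemble the
    results into a single Dolma document. Returns None if any page fails or the
    document text ends up empty."""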
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
        # TODO: Switch to aioboto3 or something similar
        data = await asyncio.to_thread(lambda: get_s3_bytes(pdf_s3, pdf_s3_path))
        tf.write(data)
        tf.flush()

        reader = PdfReader(tf.name)
        num_pages = reader.get_num_pages()

        logger.info(f"Got {num_pages} pages to do for {pdf_s3_path} in worker {worker_id}")

        # List to hold the tasks for processing each page
        page_tasks = []
        for page_num in range(1, num_pages + 1):
            # Create a task for each page
            task = asyncio.create_task(process_page(args, session, worker_id, pdf_s3_path, tf.name, page_num))
            page_tasks.append(task)

        # Gather results from all page processing tasks
        try:
            page_results: list[PageResult] = await asyncio.gather(*page_tasks)
        except Exception:
            logger.exception(f"Could not load page for {pdf_s3_path}, aborting document")
            return None

        # Build the document text and page spans
        document_text = ""
        pdf_page_spans = []
        current_char_pos = 0

        for index, page_result in enumerate(page_results):
            if page_result.response.natural_text is not None:
                content = page_result.response.natural_text + ("\n" if index == len(page_results) - 1 else "")
            else:
                content = ""

            start_pos = current_char_pos
            document_text += content
            current_char_pos = len(document_text)
            pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])

        if not document_text:
            return None  # Return None if the document text is empty

        # Build the Dolma document
        metadata = {
            "Source-File": pdf_s3_path,
            "pdf-total-pages": num_pages,
            "total-input-tokens": sum(page.input_tokens for page in page_results),
            "total-output-tokens": sum(page.output_tokens for page in page_results)
        }

        id_ = hashlib.sha1(document_text.encode()).hexdigest()

        dolma_doc = {
            "id": id_,
            "text": document_text,
            "source": "pdelfin",
            "added": datetime.datetime.now().strftime("%Y-%m-%d"),
            "created": datetime.datetime.now().strftime("%Y-%m-%d"),
            "metadata": metadata,
            "attributes": {
                "pdf_page_numbers": pdf_page_spans
            }
        }

        return dolma_doc

async def worker(args, queue, semaphore, worker_id):
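    """Consume (work_hash, pdfs) items from the queue, convert each PDF in the group,
    and upload the resulting Dolma documents as one JSONL file to the workspace.
    The shared semaphore gates when a worker may start sending requests."""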
    while True:
        [work_hash, pdfs] = await queue.get()

        try:
            await tracker.clear_work(worker_id)

            # Wait until allowed to proceed
            await semaphore.acquire()

            # TODO: Double check that the work item has not been done already by looking at the s3 workspace
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=3600),
                                             connector=aiohttp.TCPConnector(limit=1000)) as session:
                dolma_docs = await asyncio.gather(*[process_pdf(args, session, worker_id, pdf) for pdf in pdfs])
                dolma_docs = [doc for doc in dolma_docs if doc is not None]

            # Write the Dolma documents to a local temporary file in JSONL format
            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
                for doc in dolma_docs:
                    tf.write(json.dumps(doc))
                    tf.write('\n')
                tf.flush()

                # Define the output S3 path using the work_hash
                output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')

                bucket, key = parse_s3_path(output_s3_path)
                workspace_s3.upload_file(tf.name, bucket, key)

            # Update finished token counts from successful documents
            metrics.add_metrics(finished_input_tokens=sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs),
                                finished_output_tokens=sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs))

            # Update last batch time
            last_batch_time = time.perf_counter()
        except Exception as e:
            logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
        finally:
            queue.task_done()

async def sglang_server_task(args, semaphore):
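    """Download the model, patch its rope_scaling config for sglang, launch the sglang
    server as a subprocess, and watch its log output to decide when to release the
    worker semaphore."""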
    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
    download_directory(args.model, model_cache_dir)

    # Check the rope config and make sure it's got the proper key
    with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
        config_data = json.load(cfin)

    if "rope_type" in config_data["rope_scaling"]:
        del config_data["rope_scaling"]["rope_type"]
        config_data["rope_scaling"]["type"] = "mrope"

        with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
            json.dump(config_data, cfout)

    proc = await asyncio.create_subprocess_exec(
        "python3",
        "-m", "sglang.launch_server",
        "--model-path", model_cache_dir,
        "--chat-template", args.model_chat_template,
        "--context-length", str(args.model_max_context),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )

    # Make sure we kill this subprocess on exit
    def _kill_proc():
        proc.terminate()

    atexit.register(_kill_proc)

    last_running_req, last_queue_req = 0, 0  # To track transitions
    can_release_automatically = False
    last_semaphore_release = time.time()

    async def process_line(line):
        nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release
        sglang_logger.info(line)

        match = re.search(r'#running-req: (\d+)', line)
        if match:
            last_running_req = int(match.group(1))

            if last_running_req > 0:
                can_release_automatically = True

        # Parse the line and update semaphore if necessary
        match = re.search(r'#queue-req: (\d+)', line)
        if match:
            queue_req = int(match.group(1))
            logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")

            if last_queue_req != 0 and queue_req == 0:
                # Release the semaphore when queue_req transitions from non-zero to zero
                if semaphore.locked():
                    semaphore.release()
                    last_semaphore_release = time.time()
                    logger.info("Semaphore released, allowing a worker to proceed.")

            last_queue_req = queue_req

        # Also release the semaphore automatically if there have been no running requests for > 30 seconds
        if last_running_req == 0 and can_release_automatically and time.time() - last_semaphore_release > 30 and semaphore.locked():
            semaphore.release()
            last_semaphore_release = time.time()
            can_release_automatically = False
            logger.info("Semaphore released due to timeout, allowing a worker to proceed.")

    async def read_stream(stream):
        while True:
            line = await stream.readline()
            if not line:
                break
            line = line.decode('utf-8').rstrip()
            await process_line(line)

    # Start tasks to read stdout and stderr
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))

    await proc.wait()
    await stdout_task
    await stderr_task

async def sglang_server_host(args, semaphore):
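    """Keep the sglang server alive by restarting the server task whenever it exits."""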
    while True:
        await sglang_server_task(args, semaphore)
        logger.warning("SGLang server task ended")

async def sglang_server_ready():
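    """Poll the local sglang /v1/models endpoint until the server responds with 200,
    giving up after max_attempts."""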
    max_attempts = 300
    delay_sec = 1
    url = 'http://localhost:30000/v1/models'

    for attempt in range(1, max_attempts + 1):
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status == 200:
                        logger.info("sglang server is ready.")
                        return
                    else:
                        logger.info(f"Attempt {attempt}: Unexpected status code {response.status}")
        except Exception as e:
            logger.warning(f"Attempt {attempt}: {e}")

        await asyncio.sleep(delay_sec)

    raise Exception("sglang server did not become ready after waiting.")

async def metrics_reporter():
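    """Periodically log the token metrics and the per-worker status table."""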
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)

def submit_beaker_job(args):
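    """Submit this pipeline as a single-GPU Beaker experiment, writing the required
    Weka/AWS secrets to the Beaker workspace first if they are missing."""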
    from beaker import (
        Beaker,
        Constraints,
        DataMount,
        DataSource,
        EnvVar,
        ExperimentSpec,
        ImageSource,
        Priority,
        ResultSpec,
        SecretNotFound,
        TaskContext,
        TaskResources,
        TaskSpec,
    )

    b = Beaker.from_env(default_workspace=args.beaker_workspace)
    account = b.account.whoami()
    owner = account.name
    beaker_image = f"jakep/pdelfin-inference-{VERSION}"

    task_name = f"pdelfin-{os.path.basename(args.workspace.rstrip('/'))}"

    priority = "normal"

    args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]

    try:
        b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
        b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
        b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
    except SecretNotFound:
        print(f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]")

        if input().strip().lower() != "y":
            print("Exiting...")
            sys.exit(1)

        b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
        b.secret.write(f"{owner}-AWS_CREDENTIALS_FILE", open(os.path.join(os.path.expanduser('~'), '.aws', 'credentials')).read(), args.beaker_workspace)

    # Create the experiment spec
    experiment_spec = ExperimentSpec(
        budget="ai2/oe-data",
        description=task_name,
        tasks=[
            TaskSpec(
                name=task_name,
                propagate_failure=False,
                propagate_preemption=False,
                replicas=1,
                context=TaskContext(
                    priority=Priority(priority),
                    preemptible=True,
                ),
                image=ImageSource(beaker=beaker_image),
                command=["python", "-m", "pdelfin.beakerpipeline"] + args_list,
                env_vars=[
                    EnvVar(name="BEAKER_JOB_NAME", value=task_name),
                    EnvVar(name="OWNER", value=owner),
                    EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
                    EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
                    EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
                ],
                resources=TaskResources(gpu_count=1),
                constraints=Constraints(cluster=args.beaker_cluster),
                result=ResultSpec(path="/noop-results"),
            )
        ],
    )

    experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")

async def main():
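    """Parse arguments, optionally add PDFs to the work queue or submit the job to
    Beaker, then run the sglang server plus worker tasks until the queue is drained."""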
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=5, help='Number of workers to run at a time')

    # Model parameters
    parser.add_argument('--model', help='List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access',
                        default=["weka://oe-data-default/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/best_bf16/",
                                 "gs://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/",
                                 "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"])
    parser.add_argument('--model_max_context', type=int, default=8192, help="Maximum context length that the model was fine tuned under")
    parser.add_argument('--model_chat_template', type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)

    # Beaker/job running stuff
    parser.add_argument('--beaker', action='store_true', help='Submit this job to beaker instead of running locally')
    parser.add_argument('--beaker_workspace', help='Beaker workspace to submit to', default='ai2/pdelfin')
    parser.add_argument('--beaker_cluster', help='Beaker clusters you want to run on', default=["ai2/jupiter-cirrascale-2", "ai2/pluto-cirrascale", "ai2/saturn-cirrascale"])
    args = parser.parse_args()

    if "AWS_CREDENTIALS_FILE" in os.environ:
        cred_path = os.path.join(os.path.expanduser('~'), '.aws', 'credentials')
        os.makedirs(os.path.dirname(cred_path), exist_ok=True)
        with open(cred_path, "w") as f:
            f.write(os.environ.get("AWS_CREDENTIALS_FILE"))

    if args.workspace_profile:
        global workspace_s3
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        global pdf_s3
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    check_poppler_version()

    if args.pdfs:
        logger.info("Got --pdfs argument, going to add to the work queue")
        await populate_pdf_work_queue(args)

    if args.beaker:
        submit_beaker_job(args)
        return

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also output dolma docs as soon as possible
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)

    sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))

    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")

    await sglang_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter())

    # Create worker tasks to process the queue concurrently.
    worker_tasks = []
    for i in range(args.workers):
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for the queue to be fully processed
    await work_queue.join()

    # Cancel our worker tasks.
    for task in worker_tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*worker_tasks, return_exceptions=True)

    # Wait for server to stop
    sglang_server.cancel()
    await sglang_server

    metrics_task.cancel()
    await metrics_task


if __name__ == "__main__":
    asyncio.run(main())


# TODO
# Possible future addon: in beaker, discover other nodes on this same job,
# and send them a message when you take a work item off the queue.