Reworking to be async

parent a103ce730f
commit a39350e074
@@ -9,6 +9,7 @@ import subprocess
 import atexit
 import hashlib
 import base64
+import asyncio
 
 from tqdm import tqdm
 from io import BytesIO
@@ -73,8 +74,123 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))
 
     return sha1.hexdigest()
 
-if __name__ == '__main__':
+
+async def start_sglang_server(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    download_directory(args.model, model_cache_dir)
+
+    # Start up the sglang server
+    sglang_process = subprocess.Popen([
+        "python3", "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    ])
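Note: the new start_sglang_server() launches the server with a blocking subprocess.Popen and returns immediately; nothing yet waits for the model to finish loading. A minimal sketch of how readiness could be awaited without blocking the event loop is below. The host, port (sglang's usual default of 30000), and timeout are illustrative assumptions, not part of this commit:

    async def wait_for_sglang_ready(host: str = "localhost", port: int = 30000, timeout: float = 300.0) -> None:
        # Poll the TCP port until the server accepts a connection or the timeout expires.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            try:
                _, writer = await asyncio.open_connection(host, port)
                writer.close()
                await writer.wait_closed()
                return
            except OSError:
                if loop.time() > deadline:
                    raise TimeoutError(f"sglang server did not come up within {timeout} seconds")
                await asyncio.sleep(1)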
+
+
+async def populate_pdf_work_queue(args):
+    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+
+    if args.pdfs.startswith("s3://"):
+        logger.info(f"Expanding s3 glob at {args.pdfs}")
+        all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
+    elif os.path.exists(args.pdfs):
+        logger.info(f"Loading file at {args.pdfs}")
+        with open(args.pdfs, "r") as f:
+            all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
+    else:
+        raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
+
+    all_pdfs = set(all_pdfs)
+    logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+
+    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+
+    # Parse existing work items into groups
+    existing_groups = {}
+    for line in existing_lines:
+        if line.strip():
+            parts = line.strip().split(",")
+            group_hash = parts[0]
+            group_pdfs = parts[1:]
+            existing_groups[group_hash] = group_pdfs
+    existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
+
+    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
+
+    # Remove existing PDFs from all_pdfs
+    new_pdfs = all_pdfs - existing_pdf_set
+    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
+
+    # Group the new PDFs into chunks of group_size
+    # TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
+    new_groups = []
+    current_group = []
+    for pdf in sorted(new_pdfs):  # Sort for consistency
+        current_group.append(pdf)
+        if len(current_group) == args.group_size:
+            group_hash = compute_workgroup_sha1(current_group)
+            new_groups.append((group_hash, current_group))
+            current_group = []
+    if current_group:
+        group_hash = compute_workgroup_sha1(current_group)
+        new_groups.append((group_hash, current_group))
+
+    logger.info(f"Created {len(new_groups):,} new work groups")
+
+    # Combine existing groups with new groups
+    combined_groups = existing_groups.copy()
+    for group_hash, group_pdfs in new_groups:
+        combined_groups[group_hash] = group_pdfs
+
+    # Prepare lines to write back
+    combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
+
+    # Upload the combined work items back to S3
+    if new_groups:
+        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
+
+    logger.info("Completed adding new PDFs.")
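The TODO above about choosing group_size automatically (sample a few PDFs, use the mean/median page count) could look roughly like the sketch below. It assumes pypdf is available and that the sampled PDFs are already local; the helper name, the 500-pages-per-group target, and the sample size are all made up for illustration:

    import random
    from statistics import median

    from pypdf import PdfReader

    def estimate_group_size(local_pdf_paths: list[str], target_pages_per_group: int = 500, sample_size: int = 20) -> int:
        # Estimate a typical page count from a small random sample, then size
        # groups so each one holds roughly target_pages_per_group pages.
        sample = random.sample(local_pdf_paths, min(sample_size, len(local_pdf_paths)))
        pages_per_pdf = [len(PdfReader(path).pages) for path in sample]
        typical_pages = max(1, int(median(pages_per_pdf)))
        return max(1, target_pages_per_group // typical_pages)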
+
+
+async def load_pdf_work_queue(args) -> asyncio.Queue:
+    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+
+    # Read in the work queue from s3
+    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+    work_queue = {}
+    for line in work_queue_lines:
+        if line.strip():
+            parts = line.strip().split(",")
+            group_hash = parts[0]
+            group_pdfs = parts[1:]
+            work_queue[group_hash] = group_pdfs
+
+    # Read in the done items from the s3 workspace
+    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
+    done_work_hashes = set()
+    for item in done_work_items:
+        filename = os.path.basename(item)
+        if filename.startswith('output_') and filename.endswith('.jsonl'):
+            group_hash = filename[len('output_'):-len('.jsonl')]
+            done_work_hashes.add(group_hash)
+
+    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
+    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+
+    queue = asyncio.Queue()
+
+    for work in remaining_work_queue:
+        await queue.put((work, remaining_work_queue[work]))
+
+    return queue
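One caveat for load_pdf_work_queue(): download_zstd_csv and expand_s3_glob appear to be ordinary synchronous boto3-backed helpers, so calling them directly inside a coroutine blocks the event loop while S3 is read. If that matters once workers are running concurrently, the usual workaround is to push the blocking calls onto a thread. A sketch, assuming the helpers really are synchronous and Python 3.9+ for asyncio.to_thread:

    # Inside load_pdf_work_queue, run the blocking S3 helpers in a worker thread
    # so other coroutines can keep making progress while the index downloads.
    work_queue_lines = await asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
    done_work_items = await asyncio.to_thread(
        expand_s3_glob, workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl"
    )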
+
+
+async def worker(args, queue):
+    while True:
+        work = await queue.get()
+
+        logger.info(f"Got work to do for {work}")
+        queue.task_done()
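The worker() above is a stub: it takes an item, logs it, and calls task_done(). Once real per-group processing lands, an unhandled exception would both kill the worker silently and skip task_done(), leaving main()'s work_queue.join() hanging forever. A defensive sketch of the same loop; process_work_item is a hypothetical placeholder, not a function in this commit:

    async def worker(args, queue):
        while True:
            work = await queue.get()
            try:
                logger.info(f"Got work to do for {work}")
                # await process_work_item(args, work)  # hypothetical per-group processing
            except asyncio.CancelledError:
                raise  # let cancellation from main() propagate
            except Exception:
                logger.exception(f"Failed to process work group {work[0]}")
            finally:
                queue.task_done()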
+
+
+async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
     parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
@@ -101,131 +217,67 @@ if __name__ == '__main__':
         pdf_session = boto3.Session(profile_name=args.pdf_profile)
         pdf_s3 = pdf_session.client("s3")
 
-    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
     check_poppler_version()
 
-    # Check list of pdfs and that it matches what's in the workspace
     if args.pdfs:
-        if args.pdfs.startswith("s3://"):
-            logger.info(f"Expanding s3 glob at {args.pdfs}")
-            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
-        elif os.path.exists(args.pdfs):
-            logger.info(f"Loading file at {args.pdfs}")
-            with open(args.pdfs, "r") as f:
-                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
-        else:
-            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
-
-        all_pdfs = set(all_pdfs)
-        logger.info(f"Found {len(all_pdfs):,} total pdf paths")
-
-        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-
-        # Parse existing work items into groups
-        existing_groups = {}
-        for line in existing_lines:
-            if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_pdfs = parts[1:]
-                existing_groups[group_hash] = group_pdfs
-        existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
-
-        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
-
-        # Remove existing PDFs from all_pdfs
-        new_pdfs = all_pdfs - existing_pdf_set
-        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
-
-        # Group the new PDFs into chunks of group_size
-        # TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
-        new_groups = []
-        current_group = []
-        for pdf in sorted(new_pdfs):  # Sort for consistency
-            current_group.append(pdf)
-            if len(current_group) == args.group_size:
-                group_hash = compute_workgroup_sha1(current_group)
-                new_groups.append((group_hash, current_group))
-                current_group = []
-        if current_group:
-            group_hash = compute_workgroup_sha1(current_group)
-            new_groups.append((group_hash, current_group))
-
-        logger.info(f"Created {len(new_groups):,} new work groups")
-
-        # Combine existing groups with new groups
-        combined_groups = existing_groups.copy()
-        for group_hash, group_pdfs in new_groups:
-            combined_groups[group_hash] = group_pdfs
-
-        # Prepare lines to write back
-        combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
-
-        # Upload the combined work items back to S3
-        if new_groups:
-            upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
-
-        logger.info("Completed adding new PDFs.")
+        await populate_pdf_work_queue(args)
+
+    work_queue = await load_pdf_work_queue(args)
+    logger.info(f"Work queue prepared with {work_queue.qsize()} items")
+
+    # Create worker tasks to process the queue concurrently.
+    tasks = []
+    for i in range(args.workers):
+        task = asyncio.create_task(worker(args, work_queue))
+        tasks.append(task)
+
+    # Wait for the queue to be fully processed
+    await work_queue.join()
+
+    # Cancel our worker tasks.
+    for task in tasks:
+        task.cancel()
+
+    # Wait until all worker tasks are cancelled.
+    await asyncio.gather(*tasks, return_exceptions=True)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
 
     # TODO
     # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
     # If not, then your job is to do the actual work
 
     # Download the model from the best place available
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-        "python3", "-m", "sglang.launch_server",
-        "--model-path", model_cache_dir,
-        "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
-    ])
 
     # Register atexit function and signal handlers to guarantee process termination
-    def terminate_processes():
-        print("Terminating child processes...")
-        sglang_process.terminate()
-        try:
-            sglang_process.wait(timeout=30)
-        except subprocess.TimeoutExpired:
-            print("Forcing termination of child processes.")
-            sglang_process.kill()
-        print("Child processes terminated.")
+    # def terminate_processes():
+    #     print("Terminating child processes...")
+    #     sglang_process.terminate()
+    #     try:
+    #         sglang_process.wait(timeout=30)
+    #     except subprocess.TimeoutExpired:
+    #         print("Forcing termination of child processes.")
+    #         sglang_process.kill()
+    #     print("Child processes terminated.")
 
-    atexit.register(terminate_processes)
+    # atexit.register(terminate_processes)
 
-    def signal_handler(sig, frame):
-        terminate_processes()
-        sys.exit(0)
+    # def signal_handler(sig, frame):
+    #     terminate_processes()
+    #     sys.exit(0)
 
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
+    # signal.signal(signal.SIGINT, signal_handler)
+    # signal.signal(signal.SIGTERM, signal_handler)
 
-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
-
-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
-
-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
-
-    logger.info(f"Remaining work items: {len(remaining_work_queue)}")
+    # logger.info(f"Remaining work items: {len(remaining_work_queue)}")
 
     # TODO
     # Spawn up to N workers to do:
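The atexit/signal-handler block is commented out rather than ported, presumably because sglang_process stops being a module-level name once server startup moves into start_sglang_server(). One way the same cleanup could be reattached under asyncio is sketched below; it assumes start_sglang_server() is changed to return the Popen handle (it does not in this commit), that the helper is called from inside main() while the event loop is running, and that add_signal_handler is available (Unix only):

    import asyncio
    import signal
    import subprocess

    def install_termination_handlers(sglang_process: subprocess.Popen) -> None:
        # Same cleanup as the commented-out terminate_processes(), wired into the event loop.
        def terminate_processes():
            print("Terminating child processes...")
            sglang_process.terminate()
            try:
                sglang_process.wait(timeout=30)
            except subprocess.TimeoutExpired:
                print("Forcing termination of child processes.")
                sglang_process.kill()
            print("Child processes terminated.")

        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, terminate_processes)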
@@ -238,12 +290,12 @@ if __name__ == '__main__':
     # Possible future addon, in beaker, discover other nodes on this same job
     # Send them a message when you take a work item off the queue
 
-    try:
-        while True:
-            time.sleep(1)
+    # try:
+    #     while True:
+    #         time.sleep(1)
 
-            if sglang_process.returncode is not None:
-                logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
-    except KeyboardInterrupt:
-        logger.info("Got keyboard interrupt, exiting everything")
-        sys.exit(1)
+    #         if sglang_process.returncode is not None:
+    #             logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
+    # except KeyboardInterrupt:
+    #     logger.info("Got keyboard interrupt, exiting everything")
+    #     sys.exit(1)
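Likewise, the old keep-alive loop above (time.sleep(1) plus a returncode check) is only commented out here. If that health check returns, it needs asyncio.sleep so it can run alongside the workers. A minimal sketch, again assuming the Popen handle is reachable and using the module's existing logger; note poll() rather than reading returncode directly, since returncode is only refreshed by poll()/wait():

    async def monitor_sglang(sglang_process) -> None:
        # Async replacement for the commented-out while True / time.sleep(1) loop.
        while True:
            await asyncio.sleep(1)
            if sglang_process.poll() is not None:
                logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
                sys.exit(1)

Such a check could run as its own task next to the workers, e.g. asyncio.create_task(monitor_sglang(sglang_process)) inside main().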