Better work queue
This commit is contained in:
parent 04429b2862
commit e499413089
@@ -24,9 +24,10 @@ from PIL import Image
from pypdf import PdfReader
from functools import partial
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple, List, Dict, Set
from concurrent.futures import ProcessPoolExecutor

from pdelfin.s3_queue import S3WorkQueue, WorkItem
from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
@@ -123,160 +124,6 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
    }


def compute_workgroup_sha1(work_group: list[str]) -> str:
    sha1 = hashlib.sha1()
    # Ensure consistent ordering by sorting the list
    for pdf in sorted(work_group):
        sha1.update(pdf.encode('utf-8'))
    return sha1.hexdigest()


async def populate_pdf_work_queue(args):
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")

    if args.pdfs.startswith("s3://"):
        logger.info(f"Expanding s3 glob at {args.pdfs}")
        all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
    elif os.path.exists(args.pdfs):
        logger.info(f"Loading file at {args.pdfs}")
        with open(args.pdfs, "r") as f:
            all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
    else:
        raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

    all_pdfs = set(all_pdfs)
    logger.info(f"Found {len(all_pdfs):,} total pdf paths")

    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

    # Parse existing work items into groups
    existing_groups = {}
    for line in existing_lines:
        if line.strip():
            parts = line.strip().split(",")
            group_hash = parts[0]
            group_pdfs = parts[1:]
            existing_groups[group_hash] = group_pdfs
    existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)

    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

    # Remove existing PDFs from all_pdfs
    new_pdfs = all_pdfs - existing_pdf_set
    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

    sample_size = min(100, len(new_pdfs))
    sampled_pdfs = random.sample(list(new_pdfs), sample_size)

    page_counts = []

    for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimial length"):
        try:
            # Download the PDF to a temp file
            with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
                s3_bucket, s3_key = parse_s3_path(pdf)
                pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
                tmp_file.flush()
                reader = PdfReader(tmp_file.name)
                page_counts.append(len(reader.pages))
        except Exception as e:
            logger.warning(f"Failed to read {pdf}: {e}")

    if page_counts:
        avg_pages_per_pdf = sum(page_counts) / len(page_counts)
    else:
        logger.warning("Could not read any PDFs to estimate average page count.")
        avg_pages_per_pdf = 10  # Default to 10 pages per PDF if sampling fails

    group_size = max(1, int(args.pages_per_group / avg_pages_per_pdf))
    logger.info(f"Calculated group_size: {group_size} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

    new_groups = []
    current_group = []
    for pdf in sorted(new_pdfs):  # Sort for consistency
        current_group.append(pdf)
        if len(current_group) == group_size:
            group_hash = compute_workgroup_sha1(current_group)
            new_groups.append((group_hash, current_group))
            current_group = []
    if current_group:
        group_hash = compute_workgroup_sha1(current_group)
        new_groups.append((group_hash, current_group))

    logger.info(f"Created {len(new_groups):,} new work groups")

    # Combine existing groups with new groups
    combined_groups = existing_groups.copy()
    for group_hash, group_pdfs in new_groups:
        combined_groups[group_hash] = group_pdfs

    # Prepare lines to write back
    combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]

    # Upload the combined work items back to S3
    if new_groups:
        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

    logger.info("Completed adding new PDFs.")


async def load_pdf_work_queue(args) -> asyncio.Queue:
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")

    # Define the two blocking I/O operations
    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)

    # Run both tasks concurrently
    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)

    # Process the work queue lines
    work_queue = {
        parts[0]: parts[1:]
        for line in work_queue_lines
        if (parts := line.strip().split(",")) and line.strip()
    }

    # Extract done work hashes
    done_work_hashes = {
        os.path.basename(item)[len('output_'):-len('.jsonl')]
        for item in done_work_items
        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
    }

    # Determine remaining work
    remaining_work_hashes = set(work_queue) - done_work_hashes
    #remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash
    remaining_work_queue = {
        hash_: work_queue[hash_]
        for hash_ in remaining_work_hashes
    }

    # Populate the asyncio.Queue with remaining work
    queue = asyncio.Queue()
    shuffled_items = list(remaining_work_queue.items())
    random.shuffle(shuffled_items)

    for work, pdfs in shuffled_items:
        await queue.put((work, pdfs))

    return queue


async def work_item_completed(args, work_hash: str) -> bool:
    # Check if work item has already been completed
    output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
    bucket, key = parse_s3_path(output_s3_path)

    try:
        # Check if the output file already exists
        await asyncio.to_thread(workspace_s3.head_object, Bucket=bucket, Key=key)
        return True
    except workspace_s3.exceptions.ClientError as e:
        pass

    return False


async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
    COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
    MAX_RETRIES = 3
@@ -442,31 +289,32 @@ def build_dolma_document(pdf_s3_path, page_results):
    }
    return dolma_doc


async def worker(args, queue, semaphore, worker_id):
async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
    while True:
        [work_hash, pdfs] = await queue.get()
        # Wait until allowed to proceed
        await semaphore.acquire()

        try:
            await tracker.clear_work(worker_id)
        work_item = await work_queue.get_work()

        # Wait until allowed to proceed
        await semaphore.acquire()

        if work_item is None:
            logger.info(f"Worker {worker_id} exiting due to empty queue")
            semaphore.release()
            break

            if await work_item_completed(args, work_hash):
                logger.info(f"Work {work_hash} was already completed, skipping")
                continue
            else:
                logger.info(f"Proceeding with {work_hash} on worker {worker_id}")
        logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
        await tracker.clear_work(worker_id)

        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
                                             connector=aiohttp.TCPConnector(limit=1000)) as session:
                async with asyncio.TaskGroup() as tg:
                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs]
                    logger.info(f"Created all tasks for {work_hash}")
                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
                    logger.info(f"Created all tasks for {work_item.hash}")

                logger.info(f"Finished TaskGroup for worker on {work_hash}")
                logger.info(f"Finished TaskGroup for worker on {work_item.hash}")

            logger.info(f"Closed ClientSession for {work_hash}")
            logger.info(f"Closed ClientSession for {work_item.hash}")

            dolma_docs = []
            for task in dolma_tasks:
@@ -479,7 +327,7 @@ async def worker(args, queue, semaphore, worker_id):
                if result is not None:
                    dolma_docs.append(result)

            logger.info(f"Got {len(dolma_docs)} docs for {work_hash}")
            logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")

            # Write the Dolma documents to a local temporary file in JSONL format
            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
@@ -489,7 +337,7 @@ async def worker(args, queue, semaphore, worker_id):
                tf.flush()

                # Define the output S3 path using the work_hash
                output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')
                output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl')

                bucket, key = parse_s3_path(output_s3_path)
                workspace_s3.upload_file(tf.name, bucket, key)
@@ -501,9 +349,10 @@ async def worker(args, queue, semaphore, worker_id):
            # Update last batch time
            last_batch_time = time.perf_counter()
        except Exception as e:
            logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
            logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
        finally:
            queue.task_done()
            await work_queue.mark_done(work_item)
            semaphore.release()


async def sglang_server_task(args, semaphore):
@@ -563,6 +412,7 @@ async def sglang_server_task(args, semaphore):

            if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
                server_printed_ready_message = True
                last_semaphore_release = time.time()

            match = re.search(r'#running-req: (\d+)', line)
            if match:
@@ -631,10 +481,10 @@ async def sglang_server_ready():
    raise Exception("sglang server did not become ready after waiting.")


async def metrics_reporter(queue):
async def metrics_reporter(work_queue):
    while True:
        # Leading newlines preserve table formatting in logs
        logger.info(f"Queue remaining: {queue.qsize()}")
        logger.info(f"Queue remaining: {work_queue.size}")
        logger.info("\n" + str(metrics))
        logger.info("\n" + str(await tracker.get_status_table()))
        await asyncio.sleep(10)
@@ -716,13 +566,14 @@ def submit_beaker_job(args):

    print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")


def print_stats(args):
    import concurrent.futures
    from tqdm import tqdm

    # Get total work items and completed items
    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl")
    index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
    output_glob = os.path.join(args.workspace, "results", "*.jsonl")

    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
    done_work_items = expand_s3_glob(workspace_s3, output_glob)
@@ -825,9 +676,54 @@ async def main():

    check_poppler_version()

    # Create work queue
    work_queue = S3WorkQueue(workspace_s3, args.workspace)

    if args.pdfs:
        logger.info("Got --pdfs argument, going to add to the work queue")
        await populate_pdf_work_queue(args)

        # Expand s3 paths
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs)
        elif os.path.exists(args.pdfs):
            logger.info(f"Loading file at {args.pdfs}")
            with open(args.pdfs, "r") as f:
                s3_work_paths = list(filter(None, (line.strip() for line in f)))
        else:
            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

        s3_work_paths = set(s3_work_paths)
        logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add")

        # Estimate average pages per pdf
        sample_size = min(100, len(s3_work_paths))
        sampled_pdfs = random.sample(list(s3_work_paths), sample_size)
        page_counts = []

        for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
            try:
                # Download the PDF to a temp file
                with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
                    s3_bucket, s3_key = parse_s3_path(pdf)
                    pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
                    tmp_file.flush()
                    reader = PdfReader(tmp_file.name)
                    page_counts.append(len(reader.pages))
            except Exception as e:
                logger.warning(f"Failed to read {pdf}: {e}")

        if page_counts:
            avg_pages_per_pdf = sum(page_counts) / len(page_counts)
        else:
            logger.warning("Could not read any PDFs to estimate average page count.")
            avg_pages_per_pdf = 10  # Default to 10 pages per PDF if sampling fails

        items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
        logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")

        # Now call populate_queue
        await work_queue.populate_queue(s3_work_paths, items_per_group)

    if args.stats:
        print_stats(args)
@@ -839,6 +735,9 @@ async def main():

    logger.info(f"Starting pipeline with PID {os.getpid()}")

    # Initialize the work queue
    await work_queue.initialize_queue()

    # Create a semaphore to control worker access
    # We only allow one worker to move forward with requests, until the server has no more requests in its queue
    # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
@@ -847,9 +746,6 @@ async def main():

    sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))

    work_queue = await load_pdf_work_queue(args)
    logger.info(f"Work queue prepared with {work_queue.qsize()} items")

    await sglang_server_ready()

    metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@@ -860,16 +756,9 @@ async def main():
        task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
        worker_tasks.append(task)

    # Wait for the queue to be fully processed
    await work_queue.join()
    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)

    # Cancel our worker tasks.
    for task in worker_tasks:
        task.cancel()

    # Wait until all worker tasks are cancelled.
    await asyncio.gather(*worker_tasks, return_exceptions=True)

    # Wait for server to stop
    process_pool.shutdown(wait=False)
@@ -877,11 +766,12 @@ async def main():
    metrics_task.cancel()
    logger.info("Work done")


if __name__ == "__main__":
    asyncio.run(main())

# TODO
# - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
# X Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied)
# X Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things
# - Add logging of failed pages and have the stats function read them
# X Add the page rotation check and mechanism
@@ -37,12 +37,6 @@ class TestS3WorkQueue(unittest.TestCase):
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

        # Verify hash is actually SHA1
        sha1 = hashlib.sha1()
        for path in sorted(paths):
            sha1.update(path.encode('utf-8'))
        self.assertEqual(hash1, sha1.hexdigest())

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
@@ -51,7 +45,6 @@ class TestS3WorkQueue(unittest.TestCase):
        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
        self.assertIsInstance(queue._queue, asyncio.Queue)

    def asyncSetUp(self):
        """Set up async test fixtures"""