diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py index bd1fb70..d263058 100644 --- a/pdelfin/beakerpipeline.py +++ b/pdelfin/beakerpipeline.py @@ -24,9 +24,10 @@ from PIL import Image from pypdf import PdfReader from functools import partial from dataclasses import dataclass -from typing import Optional +from typing import Optional, Tuple, List, Dict, Set from concurrent.futures import ProcessPoolExecutor +from pdelfin.s3_queue import S3WorkQueue, WorkItem from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory from pdelfin.data.renderpdf import render_pdf_to_base64png from pdelfin.prompts import build_finetuning_prompt, PageResponse @@ -123,160 +124,6 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_ } -def compute_workgroup_sha1(work_group: list[str]) -> str: - sha1 = hashlib.sha1() - # Ensure consistent ordering by sorting the list - for pdf in sorted(work_group): - sha1.update(pdf.encode('utf-8')) - return sha1.hexdigest() - - -async def populate_pdf_work_queue(args): - index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd") - - if args.pdfs.startswith("s3://"): - logger.info(f"Expanding s3 glob at {args.pdfs}") - all_pdfs = expand_s3_glob(pdf_s3, args.pdfs) - elif os.path.exists(args.pdfs): - logger.info(f"Loading file at {args.pdfs}") - with open(args.pdfs, "r") as f: - all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs")))) - else: - raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)") - - all_pdfs = set(all_pdfs) - logger.info(f"Found {len(all_pdfs):,} total pdf paths") - - existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path) - - # Parse existing work items into groups - existing_groups = {} - for line in existing_lines: - if line.strip(): - parts = line.strip().split(",") - group_hash = parts[0] - group_pdfs = parts[1:] - existing_groups[group_hash] = group_pdfs - existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs) - - logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace") - - # Remove existing PDFs from all_pdfs - new_pdfs = all_pdfs - existing_pdf_set - logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace") - - sample_size = min(100, len(new_pdfs)) - sampled_pdfs = random.sample(list(new_pdfs), sample_size) - - page_counts = [] - - for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimial length"): - try: - # Download the PDF to a temp file - with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file: - s3_bucket, s3_key = parse_s3_path(pdf) - pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file) - tmp_file.flush() - reader = PdfReader(tmp_file.name) - page_counts.append(len(reader.pages)) - except Exception as e: - logger.warning(f"Failed to read {pdf}: {e}") - - if page_counts: - avg_pages_per_pdf = sum(page_counts) / len(page_counts) - else: - logger.warning("Could not read any PDFs to estimate average page count.") - avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails - - group_size = max(1, int(args.pages_per_group / avg_pages_per_pdf)) - logger.info(f"Calculated group_size: {group_size} based on average pages per PDF: {avg_pages_per_pdf:.2f}") - - new_groups = [] - current_group = [] - for pdf in sorted(new_pdfs): # Sort for consistency - 
current_group.append(pdf) - if len(current_group) == group_size: - group_hash = compute_workgroup_sha1(current_group) - new_groups.append((group_hash, current_group)) - current_group = [] - if current_group: - group_hash = compute_workgroup_sha1(current_group) - new_groups.append((group_hash, current_group)) - - logger.info(f"Created {len(new_groups):,} new work groups") - - # Combine existing groups with new groups - combined_groups = existing_groups.copy() - for group_hash, group_pdfs in new_groups: - combined_groups[group_hash] = group_pdfs - - # Prepare lines to write back - combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()] - - # Upload the combined work items back to S3 - if new_groups: - upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines) - - logger.info("Completed adding new PDFs.") - -async def load_pdf_work_queue(args) -> asyncio.Queue: - index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd") - output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl") - - # Define the two blocking I/O operations - download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path) - expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob) - - # Run both tasks concurrently - work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task) - - # Process the work queue lines - work_queue = { - parts[0]: parts[1:] - for line in work_queue_lines - if (parts := line.strip().split(",")) and line.strip() - } - - # Extract done work hashes - done_work_hashes = { - os.path.basename(item)[len('output_'):-len('.jsonl')] - for item in done_work_items - if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl') - } - - # Determine remaining work - remaining_work_hashes = set(work_queue) - done_work_hashes - #remaining_work_hashes = set(["0e779f21fbb75d38ed4242c7e5fe57fa9a636bac"]) # If you want to debug with a specific work hash - remaining_work_queue = { - hash_: work_queue[hash_] - for hash_ in remaining_work_hashes - } - - # Populate the asyncio.Queue with remaining work - queue = asyncio.Queue() - shuffled_items = list(remaining_work_queue.items()) - random.shuffle(shuffled_items) - - for work, pdfs in shuffled_items: - await queue.put((work, pdfs)) - - return queue - -async def work_item_completed(args, work_hash: str) -> bool: - # Check if work item has already been completed - output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl') - bucket, key = parse_s3_path(output_s3_path) - - try: - # Check if the output file already exists - await asyncio.to_thread(workspace_s3.head_object, Bucket=bucket, Key=key) - return True - except workspace_s3.exceptions.ClientError as e: - pass - - return False - - async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult: COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions" MAX_RETRIES = 3 @@ -442,31 +289,32 @@ def build_dolma_document(pdf_s3_path, page_results): } return dolma_doc -async def worker(args, queue, semaphore, worker_id): + +async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id): while True: - [work_hash, pdfs] = await queue.get() + # Wait until allowed to proceed + await semaphore.acquire() - try: - await tracker.clear_work(worker_id) + work_item = await work_queue.get_work() - # Wait until allowed 
to proceed - await semaphore.acquire() + if work_item is None: + logger.info(f"Worker {worker_id} exiting due to empty queue") + semaphore.release() + break - if await work_item_completed(args, work_hash): - logger.info(f"Work {work_hash} was already completed, skipping") - continue - else: - logger.info(f"Proceeding with {work_hash} on worker {worker_id}") + logger.info(f"Worker {worker_id} processing work item {work_item.hash}") + await tracker.clear_work(worker_id) + try: async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600), connector=aiohttp.TCPConnector(limit=1000)) as session: async with asyncio.TaskGroup() as tg: - dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in pdfs] - logger.info(f"Created all tasks for {work_hash}") + dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths] + logger.info(f"Created all tasks for {work_item.hash}") - logger.info(f"Finished TaskGroup for worker on {work_hash}") + logger.info(f"Finished TaskGroup for worker on {work_item.hash}") - logger.info(f"Closed ClientSession for {work_hash}") + logger.info(f"Closed ClientSession for {work_item.hash}") dolma_docs = [] for task in dolma_tasks: @@ -479,7 +327,7 @@ async def worker(args, queue, semaphore, worker_id): if result is not None: dolma_docs.append(result) - logger.info(f"Got {len(dolma_docs)} docs for {work_hash}") + logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}") # Write the Dolma documents to a local temporary file in JSONL format with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf: @@ -489,7 +337,7 @@ async def worker(args, queue, semaphore, worker_id): tf.flush() # Define the output S3 path using the work_hash - output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl') + output_s3_path = os.path.join(args.workspace, 'results', f'output_{work_item.hash}.jsonl') bucket, key = parse_s3_path(output_s3_path) workspace_s3.upload_file(tf.name, bucket, key) @@ -501,9 +349,10 @@ async def worker(args, queue, semaphore, worker_id): # Update last batch time last_batch_time = time.perf_counter() except Exception as e: - logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}") + logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}") finally: - queue.task_done() + await work_queue.mark_done(work_item) + semaphore.release() async def sglang_server_task(args, semaphore): @@ -563,6 +412,7 @@ async def sglang_server_task(args, semaphore): if not server_printed_ready_message and "The server is fired up and ready to roll!" 
in line: server_printed_ready_message = True + last_semaphore_release = time.time() match = re.search(r'#running-req: (\d+)', line) if match: @@ -631,10 +481,10 @@ async def sglang_server_ready(): raise Exception("sglang server did not become ready after waiting.") -async def metrics_reporter(queue): +async def metrics_reporter(work_queue): while True: # Leading newlines preserve table formatting in logs - logger.info(f"Queue remaining: {queue.qsize()}") + logger.info(f"Queue remaining: {work_queue.size}") logger.info("\n" + str(metrics)) logger.info("\n" + str(await tracker.get_status_table())) await asyncio.sleep(10) @@ -716,13 +566,14 @@ def submit_beaker_job(args): print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}") + def print_stats(args): import concurrent.futures from tqdm import tqdm # Get total work items and completed items - index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd") - output_glob = os.path.join(args.workspace, "dolma_documents", "*.jsonl") + index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd") + output_glob = os.path.join(args.workspace, "results", "*.jsonl") work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path) done_work_items = expand_s3_glob(workspace_s3, output_glob) @@ -825,9 +676,54 @@ async def main(): check_poppler_version() + # Create work queue + work_queue = S3WorkQueue(workspace_s3, args.workspace) + if args.pdfs: logger.info("Got --pdfs argument, going to add to the work queue") - await populate_pdf_work_queue(args) + + # Expand s3 paths + if args.pdfs.startswith("s3://"): + logger.info(f"Expanding s3 glob at {args.pdfs}") + s3_work_paths = expand_s3_glob(pdf_s3, args.pdfs) + elif os.path.exists(args.pdfs): + logger.info(f"Loading file at {args.pdfs}") + with open(args.pdfs, "r") as f: + s3_work_paths = list(filter(None, (line.strip() for line in f))) + else: + raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)") + + s3_work_paths = set(s3_work_paths) + logger.info(f"Found {len(s3_work_paths):,} total pdf paths to add") + + # Estimate average pages per pdf + sample_size = min(100, len(s3_work_paths)) + sampled_pdfs = random.sample(list(s3_work_paths), sample_size) + page_counts = [] + + for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"): + try: + # Download the PDF to a temp file + with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file: + s3_bucket, s3_key = parse_s3_path(pdf) + pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file) + tmp_file.flush() + reader = PdfReader(tmp_file.name) + page_counts.append(len(reader.pages)) + except Exception as e: + logger.warning(f"Failed to read {pdf}: {e}") + + if page_counts: + avg_pages_per_pdf = sum(page_counts) / len(page_counts) + else: + logger.warning("Could not read any PDFs to estimate average page count.") + avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails + + items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf)) + logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}") + + # Now call populate_queue + await work_queue.populate_queue(s3_work_paths, items_per_group) if args.stats: print_stats(args) @@ -839,6 +735,9 @@ async def main(): logger.info(f"Starting pipeline with PID {os.getpid()}") + # Initialize the work queue + await work_queue.initialize_queue() + # Create a semaphore to control worker access # We only allow 
one worker to move forward with requests, until the server has no more requests in its queue # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible @@ -847,9 +746,6 @@ async def main(): sglang_server = asyncio.create_task(sglang_server_host(args, semaphore)) - work_queue = await load_pdf_work_queue(args) - logger.info(f"Work queue prepared with {work_queue.qsize()} items") - await sglang_server_ready() metrics_task = asyncio.create_task(metrics_reporter(work_queue)) @@ -860,16 +756,9 @@ async def main(): task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i)) worker_tasks.append(task) - # Wait for the queue to be fully processed - await work_queue.join() + # Wait for all worker tasks to finish + await asyncio.gather(*worker_tasks) - # Cancel our worker tasks. - for task in worker_tasks: - task.cancel() - - # Wait until all worker tasks are cancelled. - await asyncio.gather(*worker_tasks, return_exceptions=True) - # Wait for server to stop process_pool.shutdown(wait=False) @@ -877,11 +766,12 @@ async def main(): metrics_task.cancel() logger.info("Work done") + if __name__ == "__main__": asyncio.run(main()) # TODO - # - Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied) + # X Refactor the work queue into its own file so it's reusable and generic, and it makes temporary work files (prevent issue where if a work item is done, then it stalls because queue was just emptied) # X Fix the queue release mechanism so that it just does a timeout, based on zero queue size only, so you don't block things # - Add logging of failed pages and have the stats function read them # X Add the page rotation check and mechanism diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py index a78bd6a..c814e0a 100644 --- a/tests/test_s3_work_queue.py +++ b/tests/test_s3_work_queue.py @@ -37,12 +37,6 @@ class TestS3WorkQueue(unittest.TestCase): hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths)) self.assertEqual(hash1, hash2) - # Verify hash is actually SHA1 - sha1 = hashlib.sha1() - for path in sorted(paths): - sha1.update(path.encode('utf-8')) - self.assertEqual(hash1, sha1.hexdigest()) - def test_init(self): """Test initialization of S3WorkQueue""" client = Mock() @@ -51,7 +45,6 @@ class TestS3WorkQueue(unittest.TestCase): self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace") self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd") self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl") - self.assertIsInstance(queue._queue, asyncio.Queue) def asyncSetUp(self): """Set up async test fixtures"""
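
Note: this diff imports S3WorkQueue and WorkItem from pdelfin/s3_queue.py, but that new module is not part of the patch shown here. As a rough orientation only, the skeleton below is reconstructed from the call sites above and from the updated tests; the method bodies, the hashing algorithm (the tests no longer pin it to SHA1), and the persistence details (index file, claim/temporary work files, timeouts) are assumptions, not the actual module.

    # Hypothetical sketch of the interface the pipeline code relies on.
    # Names mirror the call sites in beakerpipeline.py and tests/test_s3_work_queue.py;
    # everything else is illustrative.
    import asyncio
    import hashlib
    import posixpath
    from dataclasses import dataclass
    from typing import List, Optional, Sequence


    @dataclass
    class WorkItem:
        """One unit of work: a stable group hash plus the PDF paths in that group."""
        hash: str
        s3_work_paths: List[str]


    class S3WorkQueue:
        def __init__(self, s3_client, workspace_path: str):
            self.s3_client = s3_client
            self.workspace_path = workspace_path.rstrip("/")
            # Paths asserted by the tests: the index file and the results glob.
            self._index_path = posixpath.join(self.workspace_path, "work_index_list.csv.zstd")
            self._output_glob = posixpath.join(self.workspace_path, "results", "*.jsonl")
            # Internal storage is a simplification; the real module may differ.
            self._queue: asyncio.Queue = asyncio.Queue()

        @staticmethod
        def _compute_workgroup_hash(s3_work_paths: Sequence[str]) -> str:
            # Order-independent hash of a group of paths (sorted before hashing).
            # SHA1 is used here only for illustration; the removed test assertion
            # suggests the real implementation may use a different digest.
            digest = hashlib.sha1()
            for path in sorted(s3_work_paths):
                digest.update(path.encode("utf-8"))
            return digest.hexdigest()

        async def populate_queue(self, s3_work_paths: Sequence[str], items_per_group: int) -> None:
            # Group the paths, hash each group, and merge the new groups into the
            # index file on S3 (omitted here).
            ...

        async def initialize_queue(self) -> None:
            # Download the index, drop groups whose results already exist under
            # results/*.jsonl, shuffle, and fill the internal queue (omitted here).
            ...

        async def get_work(self) -> Optional[WorkItem]:
            # Return the next WorkItem, or None once the queue is exhausted.
            ...

        async def mark_done(self, work_item: WorkItem) -> None:
            # Record completion so other workers skip this group (omitted here).
            ...

        @property
        def size(self) -> int:
            return self._queue.qsize()

Under this interface, each worker loops on get_work()/mark_done() and exits when get_work() returns None, which is what lets main() replace the old queue.join() plus task-cancellation pattern with a plain asyncio.gather over the worker tasks.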