diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index b62241f..5673c36 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -9,6 +9,7 @@ import subprocess
 import atexit
 import hashlib
 import base64
+import asyncio
 
 from tqdm import tqdm
 from io import BytesIO
@@ -73,8 +74,123 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))
     return sha1.hexdigest()
 
+async def start_sglang_server(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    download_directory(args.model, model_cache_dir)
 
-if __name__ == '__main__':
+    # Start up the sglang server
+    sglang_process = subprocess.Popen([
+        "python3", "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    ])
+
+    # Return the handle so the caller can monitor and terminate the server
+    return sglang_process
+
+async def populate_pdf_work_queue(args):
+    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+
+    if args.pdfs.startswith("s3://"):
+        logger.info(f"Expanding s3 glob at {args.pdfs}")
+        all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
+    elif os.path.exists(args.pdfs):
+        logger.info(f"Loading file at {args.pdfs}")
+        with open(args.pdfs, "r") as f:
+            all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
+    else:
+        raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")
+
+    all_pdfs = set(all_pdfs)
+    logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+
+    existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+
+    # Parse existing work items into groups
+    existing_groups = {}
+    for line in existing_lines:
+        if line.strip():
+            parts = line.strip().split(",")
+            group_hash = parts[0]
+            group_pdfs = parts[1:]
+            existing_groups[group_hash] = group_pdfs
+    existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
+
+    logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
+
+    # Remove existing PDFs from all_pdfs
+    new_pdfs = all_pdfs - existing_pdf_set
+    logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
+
+    # Group the new PDFs into chunks of group_size
+    # TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
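+    # A rough sketch of that auto-sizing idea (hypothetical, not wired in; assumes
+    # pypdf is installed, paths are local, and a made-up TARGET_PAGES_PER_GROUP
+    # constant -- s3 paths would first need a download step):
+    #
+    #   import random, statistics
+    #   from pypdf import PdfReader
+    #   sample = random.sample(sorted(new_pdfs), min(10, len(new_pdfs)))
+    #   median_pages = statistics.median(len(PdfReader(p).pages) for p in sample)
+    #   group_size = max(1, int(TARGET_PAGES_PER_GROUP / median_pages))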
+    new_groups = []
+    current_group = []
+    for pdf in sorted(new_pdfs):  # Sort for consistency
+        current_group.append(pdf)
+        if len(current_group) == args.group_size:
+            group_hash = compute_workgroup_sha1(current_group)
+            new_groups.append((group_hash, current_group))
+            current_group = []
+    if current_group:
+        group_hash = compute_workgroup_sha1(current_group)
+        new_groups.append((group_hash, current_group))
+
+    logger.info(f"Created {len(new_groups):,} new work groups")
+
+    # Combine existing groups with new groups
+    combined_groups = existing_groups.copy()
+    for group_hash, group_pdfs in new_groups:
+        combined_groups[group_hash] = group_pdfs
+
+    # Prepare lines to write back
+    combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
+
+    # Upload the combined work items back to S3
+    if new_groups:
+        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
+
+    logger.info("Completed adding new PDFs.")
+
+async def load_pdf_work_queue(args) -> asyncio.Queue:
+    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+
+    # Read in the work queue from s3
+    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+    work_queue = {}
+    for line in work_queue_lines:
+        if line.strip():
+            parts = line.strip().split(",")
+            group_hash = parts[0]
+            group_pdfs = parts[1:]
+            work_queue[group_hash] = group_pdfs
+
+    # Read in the done items from the s3 workspace
+    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
+    done_work_hashes = set()
+    for item in done_work_items:
+        filename = os.path.basename(item)
+        if filename.startswith('output_') and filename.endswith('.jsonl'):
+            group_hash = filename[len('output_'):-len('.jsonl')]
+            done_work_hashes.add(group_hash)
+
+    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
+    remaining_work_queue = {group_hash: work_queue[group_hash] for group_hash in remaining_work_hashes}
+
+    queue = asyncio.Queue()
+
+    for work in remaining_work_queue:
+        await queue.put((work, remaining_work_queue[work]))
+
+    return queue
+
+async def worker(args, queue):
+    while True:
+        work = await queue.get()
+
+        logger.info(f"Got work to do for {work}")
+        queue.task_done()
+
+
+async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
     parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
@@ -101,131 +217,67 @@ if __name__ == '__main__':
         pdf_session = boto3.Session(profile_name=args.pdf_profile)
         pdf_s3 = pdf_session.client("s3")
 
-    index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-
     check_poppler_version()
 
-    # Check list of pdfs and that it matches what's in the workspace
     if args.pdfs:
-        if args.pdfs.startswith("s3://"):
-            logger.info(f"Expanding s3 glob at {args.pdfs}")
-            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
-        elif os.path.exists(args.pdfs):
-            logger.info(f"Loading file at {args.pdfs}")
-            with open(args.pdfs, "r") as f:
-                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
-        else:
-            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
+        await populate_pdf_work_queue(args)
 
-        all_pdfs = set(all_pdfs)
-        logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+    work_queue = await load_pdf_work_queue(args)
+    logger.info(f"Work queue prepared with {work_queue.qsize()} items")
 
-        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-
-        # Parse existing work items into groups
-        existing_groups = {}
-        for line in existing_lines:
-            if line.strip():
-                parts = line.strip().split(",")
-                group_hash = parts[0]
-                group_pdfs = parts[1:]
-                existing_groups[group_hash] = group_pdfs
-        existing_pdf_set = set(pdf for group_pdfs in existing_groups.values() for pdf in group_pdfs)
+    # Create worker tasks to process the queue concurrently.
+    tasks = []
+    for i in range(args.workers):
+        task = asyncio.create_task(worker(args, work_queue))
+        tasks.append(task)
 
-        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
+    # Wait for the queue to be fully processed
+    await work_queue.join()
 
-        # Remove existing PDFs from all_pdfs
-        new_pdfs = all_pdfs - existing_pdf_set
-        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
+    # Cancel our worker tasks.
+    for task in tasks:
+        task.cancel()
 
-        # Group the new PDFs into chunks of group_size
-        # TODO: Figure out the group size automatically by sampling a few pdfs, and taking the mean/median number of pages, etc.
-        new_groups = []
-        current_group = []
-        for pdf in sorted(new_pdfs):  # Sort for consistency
-            current_group.append(pdf)
-            if len(current_group) == args.group_size:
-                group_hash = compute_workgroup_sha1(current_group)
-                new_groups.append((group_hash, current_group))
-                current_group = []
-        if current_group:
-            group_hash = compute_workgroup_sha1(current_group)
-            new_groups.append((group_hash, current_group))
+    # Wait until all worker tasks are cancelled.
+    await asyncio.gather(*tasks, return_exceptions=True)
+
-        logger.info(f"Created {len(new_groups):,} new work groups")
+if __name__ == "__main__":
+    asyncio.run(main())
 
-        # Combine existing groups with new groups
-        combined_groups = existing_groups.copy()
-        for group_hash, group_pdfs in new_groups:
-            combined_groups[group_hash] = group_pdfs
 
-        # Prepare lines to write back
-        combined_lines = [",".join([group_hash] + group_pdfs) for group_hash, group_pdfs in combined_groups.items()]
 
-        # Upload the combined work items back to S3
-        if new_groups:
-            upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
-
-        logger.info("Completed adding new PDFs.")
 
     # TODO
     # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
    # If not, then your job is to do the actual work
 
     # Download the model from the best place available
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-        "python3", "-m", "sglang.launch_server",
-        "--model-path", model_cache_dir,
-        "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
-    ])
+
 
     # Register atexit function and signal handlers to guarantee process termination
-    def terminate_processes():
-        print("Terminating child processes...")
-        sglang_process.terminate()
-        try:
-            sglang_process.wait(timeout=30)
-        except subprocess.TimeoutExpired:
-            print("Forcing termination of child processes.")
-            sglang_process.kill()
-        print("Child processes terminated.")
+    # def terminate_processes():
+    #     print("Terminating child processes...")
+    #     sglang_process.terminate()
+    #     try:
+    #         sglang_process.wait(timeout=30)
+    #     except subprocess.TimeoutExpired:
+    #         print("Forcing termination of child processes.")
+    #         sglang_process.kill()
+    #     print("Child processes terminated.")
 
-    atexit.register(terminate_processes)
+    # atexit.register(terminate_processes)
 
-    def signal_handler(sig, frame):
-        terminate_processes()
-        sys.exit(0)
+    # def signal_handler(sig, frame):
+    #     terminate_processes()
+    #     sys.exit(0)
 
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGTERM, signal_handler)
+    # signal.signal(signal.SIGINT, signal_handler)
+    # signal.signal(signal.SIGTERM, signal_handler)
 
-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
-
-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
-
-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
-
-    logger.info(f"Remaining work items: {len(remaining_work_queue)}")
+
+    # logger.info(f"Remaining work items: {len(remaining_work_queue)}")
 
     # TODO
     # Spawn up to N workers to do:
@@ -238,12 +290,12 @@ if __name__ == '__main__':
     # Possible future addon, in beaker, discover other nodes on this same job
     # Send them a message when you take a work item off the queue
 
-    try:
-        while True:
-            time.sleep(1)
+    # try:
+    #     while True:
+    #         time.sleep(1)
 
-            if sglang_process.returncode is not None:
-                logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
-    except KeyboardInterrupt:
-        logger.info("Got keyboard interrupt, exiting everything")
-        sys.exit(1)
+    #         if sglang_process.returncode is not None:
+    #             logger.error(f"Sglang server exited with code {sglang_process.returncode}, exiting.")
+    # except KeyboardInterrupt:
+    #     logger.info("Got keyboard interrupt, exiting everything")
+    #     sys.exit(1)
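+
+    # Sketch of how that health check could come back under asyncio (hypothetical;
+    # assumes the Popen handle returned by start_sglang_server is passed through):
+    #
+    #   async def monitor_sglang(proc):
+    #       while proc.poll() is None:
+    #           await asyncio.sleep(1)
+    #       logger.error(f"Sglang server exited with code {proc.returncode}, exiting.")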