diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 7268ad3..1432b3b 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -1,226 +1,139 @@
+import logging
 import argparse
-import subprocess
-import signal
-import sys
-import os
-import time
-import tempfile
-import redis
-import redis.exceptions
-import random
 import boto3
-import atexit
+import os
-from pdelfin.s3_utils import expand_s3_glob
+from tqdm import tqdm
+from urllib.parse import urlparse
+import zstandard as zstd
+from io import BytesIO, TextIOWrapper
+from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes
+
+# Basic logging setup for now
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logging.basicConfig(level=logging.INFO)
+
+# Quiet logs from pypdf
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+
+# Global s3 client for the whole script, feel free to adjust params if you need it
 workspace_s3 = boto3.client('s3')
 pdf_s3 = boto3.client('s3')
-LOCK_KEY = "queue_populating"
-LOCK_TIMEOUT = 30 # seconds
-def populate_queue_if_empty(queue, s3_glob_path, redis_client):
-    """
-    Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
-    Only one worker should populate the queue at a time.
-    """
-    if queue.llen("work_queue") == 0:
-        # Attempt to acquire the lock
-        lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
-        if lock_acquired:
-            print("Acquired lock to populate the queue.")
-            try:
-                paths = expand_s3_glob(pdf_s3, s3_glob_path)
-                if not paths:
-                    print("No paths found to populate the queue.")
-                    return
-                for path in paths:
-                    queue.rpush("work_queue", path)
-                print("Queue populated with initial work items.")
-            except Exception as e:
-                print(f"Error populating queue: {e}")
-                # Optionally, handle retry logic or alerting here
-            finally:
-                # Release the lock
-                redis_client.delete(LOCK_KEY)
-                print("Released lock after populating the queue.")
-        else:
-            print("Another worker is populating the queue. Waiting for it to complete.")
-            # Optionally, wait until the queue is populated
-            wait_for_queue_population(queue)
+def download_zstd_csv(s3_client, s3_path):
+    """Download and decompress a .zstd CSV file from S3."""
+    try:
+        compressed_data = get_s3_bytes(s3_client, s3_path)
+        dctx = zstd.ZstdDecompressor()
+        decompressed = dctx.decompress(compressed_data)
+        text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
+        lines = text_stream.readlines()
+        logger.info(f"Downloaded and decompressed {s3_path}")
+        return lines
+    except s3_client.exceptions.NoSuchKey:
+        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
+        return []
-def wait_for_queue_population(queue, wait_time=5, max_wait=60):
-    """
-    Wait until the queue is populated by another worker.
-    """
-    elapsed = 0
-    while elapsed < max_wait:
-        queue_length = queue.llen("work_queue")
-        if queue_length > 0:
-            print("Queue has been populated by another worker.")
-            return
-        print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
-        time.sleep(wait_time)
-        elapsed += wait_time
-    print("Timeout waiting for queue to be populated.")
-    sys.exit(1)
-def process(item):
-    # Simulate processing time between 1 and 3 seconds
-    print(f"Processing item: {item}")
-    time.sleep(0.5)
-    print(f"Completed processing item: {item}")
+def upload_zstd_csv(s3_client, s3_path, lines):
+    """Compress and upload a list of lines as a .zstd CSV file to S3."""
+    joined_text = "\n".join(lines)
+    compressor = zstd.ZstdCompressor()
+    compressed = compressor.compress(joined_text.encode('utf-8'))
+    put_s3_bytes(s3_client, s3_path, compressed)
+    logger.info(f"Uploaded compressed {s3_path}")
-def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
-    """
-    Obtain a Redis client using Sentinel, with retry logic.
-    """
-    elapsed = 0
-    wait_interval = 1 # seconds
-    while elapsed < max_wait:
-        try:
-            r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
-            r.ping()
-            print(f"Connected to Redis master at {leader_ip}:{leader_port}")
-            return r
-        except redis.exceptions.ConnectionError as e:
-            print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
-            time.sleep(wait_interval)
-            elapsed += wait_interval
-    print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
-    sys.exit(1)
-def main():
-    parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
-    parser.add_argument('--leader-ip', help='IP address of the initial leader node')
-    parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
-    parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
-    parser.add_argument('--add-pdfs', help='S3 glob path for work items')
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
+    parser.add_argument('workspace', help='The S3 path where work will be done, e.g., s3://bucket/prefix/')
+    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or a path to a file containing a list of pdf paths', default=None)
+    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
+    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
+    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
+    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
     args = parser.parse_args()
-    replica_number = args.replica
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.client("s3")
-    base_redis_port = 6379
-    base_sentinel_port = 26379
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.client("s3")
-    redis_port = base_redis_port + replica_number
-    sentinel_port = base_sentinel_port + replica_number
-
-    if replica_number == 0:
-        leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
-        leader_port = args.leader_port
-    else:
-        if not args.leader_ip:
-            print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
-            sys.exit(1)
-        leader_ip = args.leader_ip
-        leader_port = args.leader_port
-
-    temp_dir = tempfile.mkdtemp()
-    redis_conf_path = os.path.join(temp_dir, 'redis.conf')
-    sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')
-
-    print("Redis config path:", redis_conf_path)
-
-    with open(redis_conf_path, 'w') as f:
-        f.write(f'port {redis_port}\n')
-        f.write(f'dbfilename dump-{replica_number}.rdb\n')
-        f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
-        f.write(f'logfile "redis-{replica_number}.log"\n')
-        f.write(f'dir {temp_dir}\n')
-        if replica_number == 0:
-            f.write('bind 0.0.0.0\n')
+    # Check the list of pdfs and make sure it matches what's in the workspace
+    if args.pdfs:
+        if args.pdfs.startswith("s3://"):
+            logger.info(f"Expanding s3 glob at {args.pdfs}")
+            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
+        elif os.path.exists(args.pdfs):
+            logger.info(f"Loading file at {args.pdfs}")
+            with open(args.pdfs, "r") as f:
+                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
         else:
-            f.write(f'replicaof {leader_ip} {leader_port}\n')
+            raise ValueError("pdfs argument needs to be either an s3 glob search path or a local file containing pdf paths (one per line)")
-    master_name = 'mymaster'
-    quorum = 1
+        all_pdfs = set(all_pdfs)
+        logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+
+        index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    with open(sentinel_conf_path, 'w') as f:
-        f.write(f'port {sentinel_port}\n')
-        f.write(f'dir {temp_dir}\n')
-        f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
-        f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
-        f.write(f'sentinel failover-timeout {master_name} 10000\n')
-        f.write(f'sentinel parallel-syncs {master_name} 1\n')
+        # Parse existing work items into groups
+        existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
+        existing_pdf_set = set(pdf for group in existing_groups for pdf in group)
-    redis_process = subprocess.Popen(['redis-server', redis_conf_path])
-    sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])
+        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
-    # Register atexit function to guarantee process termination
-    def terminate_processes():
-        print("Terminating child processes...")
-        redis_process.terminate()
-        sentinel_process.terminate()
-        try:
-            redis_process.wait(timeout=5)
-            sentinel_process.wait(timeout=5)
-        except subprocess.TimeoutExpired:
-            print("Forcing termination of child processes.")
-            redis_process.kill()
-            sentinel_process.kill()
-        print("Child processes terminated.")
+        # Remove existing PDFs from all_pdfs
+        new_pdfs = all_pdfs - existing_pdf_set
+        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
-    atexit.register(terminate_processes)
+        # Group the new PDFs into chunks of group_size
+        new_groups = []
+        current_group = []
+        for pdf in sorted(new_pdfs): # Sort for consistency
+            current_group.append(pdf)
+            if len(current_group) == args.group_size:
+                new_groups.append(current_group)
+                current_group = []
+        if current_group:
+            new_groups.append(current_group)
-    # Also handle signal-based termination
-    def handle_signal(signum, frame):
-        print(f"Received signal {signum}. Terminating processes...")
-        terminate_processes()
-        sys.exit(0)
+        logger.info(f"Created {len(new_groups):,} new work groups")
-    signal.signal(signal.SIGINT, handle_signal)
-    signal.signal(signal.SIGTERM, handle_signal)
+        # Combine existing groups with new groups
+        combined_groups = existing_groups + new_groups
-    time.sleep(2)
+        # Prepare lines to write back
+        combined_lines = [",".join(group) for group in combined_groups]
-    # Use Sentinel to connect to the master
-    from redis.sentinel import Sentinel
-    sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)
+        # Upload the combined work items back to S3
+        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
-    # Initial connection to Redis master
-    redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
+        logger.info("Completed adding new PDFs.")
-    # Populate the work queue if it's empty, using a distributed lock
-    populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)
-    try:
-        while True:
-            try:
-                # Try to get an item from the queue with a 1-minute timeout for processing
-                work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
-                if work_item:
-                    try:
-                        process(work_item)
-                        # Remove from the processing queue if processed successfully
-                        redis_client.lrem("processing_queue", 1, work_item)
-                    except Exception as e:
-                        print(f"Error processing {work_item}: {e}")
-                        # If an error occurs, let it be requeued after timeout
+    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
+    # If not, then your job is to do the actual work
-                queue_length = redis_client.llen("work_queue")
-                print(f"Total work items in queue: {queue_length}")
+    # Start up the sglang server
-                time.sleep(0.1)
+    # Read in the work queue from s3
+    # Read in the done items from the s3 workspace
-            except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
-                print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
-                # Attempt to reconnect using Sentinel
-                while True:
-                    try:
-                        redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
-                        print("Reconnected to Redis master.")
-                        break # Exit the reconnection loop and resume work
-                    except redis.exceptions.ConnectionError:
-                        print("Reconnection failed. Retrying in 5 seconds...")
-                        time.sleep(5)
-            except Exception as e:
-                print(f"Unexpected error: {e}")
-                handle_signal(None, None)
+    # Spawn up to N workers to do:
+    # In a loop, take a random work item, read in the pdfs, and queue up their requests
+    # Get results back, retry any failed pages
+    # Check periodically if that work is done in s3; if so, abandon this work
+    # Save results back to the s3 workspace output folder
-    except KeyboardInterrupt:
-        handle_signal(None, None)
-
-if __name__ == '__main__':
-    main()
+    # Possible future add-on: in beaker, discover other nodes on this same job
+    # Send them a message when you take a work item off the queue
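
For reference, the index-building logic added above boils down to set arithmetic over pdf_index_list.csv.zstd, where each line is one work item holding a comma-separated group of pdf paths: only pdfs not already present in some group are chunked into new groups, so re-running with the same --pdfs argument adds nothing. A minimal illustration with made-up values (the chunking is written as a slice expression here, equivalent to the accumulator loop in the patch):

# Made-up example values; in the patch these come from S3 and argparse.
existing_lines = ["s3://bucket/pdfs/a.pdf,s3://bucket/pdfs/b.pdf"]
all_pdfs = {"s3://bucket/pdfs/a.pdf", "s3://bucket/pdfs/b.pdf",
            "s3://bucket/pdfs/c.pdf", "s3://bucket/pdfs/d.pdf"}
group_size = 2

existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
existing_pdf_set = {pdf for group in existing_groups for pdf in group}

new_pdfs = sorted(all_pdfs - existing_pdf_set)  # only c.pdf and d.pdf are new
new_groups = [new_pdfs[i:i + group_size] for i in range(0, len(new_pdfs), group_size)]

combined_lines = [",".join(group) for group in existing_groups + new_groups]
# combined_lines == ["s3://bucket/pdfs/a.pdf,s3://bucket/pdfs/b.pdf",
#                    "s3://bucket/pdfs/c.pdf,s3://bucket/pdfs/d.pdf"]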
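
The trailing comments describe the worker side that this patch intentionally leaves unimplemented. As a rough sketch of that plan only, a worker loop built on the helpers added above might look something like the following if it lived in this module; process_pdf, run_worker, and the output/ naming scheme are hypothetical placeholders, and a real implementation would also start the sglang server, retry failed pages, and periodically check s3 so it can abandon work another node has already finished.

# --- Hypothetical sketch only; nothing below is part of this patch. ---
# Assumes it lives in this module, so os, workspace_s3, download_zstd_csv,
# and put_s3_bytes are already available; json and random are extra imports.
import json
import random

def process_pdf(pdf_s3_path):
    # Placeholder for rendering a pdf and running its pages through the inference server
    return {"pdf": pdf_s3_path, "pages": []}

def run_worker(workspace):
    # Read in the work queue from s3 (one comma-separated group of pdfs per line)
    index_file_s3_path = os.path.join(workspace, "pdf_index_list.csv.zstd")
    lines = download_zstd_csv(workspace_s3, index_file_s3_path)
    work_groups = [line.strip().split(",") for line in lines if line.strip()]

    while work_groups:
        # Take a random work item off the queue
        group = random.choice(work_groups)
        work_groups.remove(group)

        # Read in the pdfs and queue up their requests; page-level retries and the
        # "has another node already finished this?" check would also belong here
        results = [process_pdf(pdf) for pdf in group]

        # Save results back to the s3 workspace output folder
        # (the output/ prefix and the group naming scheme are assumptions)
        output_path = os.path.join(workspace, "output", f"group_{abs(hash(tuple(group)))}.jsonl")
        put_s3_bytes(workspace_s3, output_path, "\n".join(json.dumps(r) for r in results).encode("utf-8"))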