Starting on a new approach

Jake Poznanski 2024-11-07 18:21:23 +00:00
parent faf8659028
commit 12a91ffa96


@@ -1,226 +1,139 @@

Removed (the previous Redis Sentinel-based worker-queue prototype):

import argparse
import subprocess
import signal
import sys
import os
import time
import tempfile
import redis
import redis.exceptions
import random
import boto3
import atexit

from pdelfin.s3_utils import expand_s3_glob

workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

LOCK_KEY = "queue_populating"
LOCK_TIMEOUT = 30  # seconds

def populate_queue_if_empty(queue, s3_glob_path, redis_client):
    """
    Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
    Only one worker should populate the queue at a time.
    """
    if queue.llen("work_queue") == 0:
        # Attempt to acquire the lock
        lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
        if lock_acquired:
            print("Acquired lock to populate the queue.")
            try:
                paths = expand_s3_glob(pdf_s3, s3_glob_path)
                if not paths:
                    print("No paths found to populate the queue.")
                    return
                for path in paths:
                    queue.rpush("work_queue", path)
                print("Queue populated with initial work items.")
            except Exception as e:
                print(f"Error populating queue: {e}")
                # Optionally, handle retry logic or alerting here
            finally:
                # Release the lock
                redis_client.delete(LOCK_KEY)
                print("Released lock after populating the queue.")
        else:
            print("Another worker is populating the queue. Waiting for it to complete.")
            # Optionally, wait until the queue is populated
            wait_for_queue_population(queue)

def wait_for_queue_population(queue, wait_time=5, max_wait=60):
    """
    Wait until the queue is populated by another worker.
    """
    elapsed = 0
    while elapsed < max_wait:
        queue_length = queue.llen("work_queue")
        if queue_length > 0:
            print("Queue has been populated by another worker.")
            return
        print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
        time.sleep(wait_time)
        elapsed += wait_time
    print("Timeout waiting for queue to be populated.")
    sys.exit(1)

def process(item):
    # Simulate processing time between 1 and 3 seconds
    print(f"Processing item: {item}")
    time.sleep(0.5)
    print(f"Completed processing item: {item}")

def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
    """
    Obtain a Redis client using Sentinel, with retry logic.
    """
    elapsed = 0
    wait_interval = 1  # seconds
    while elapsed < max_wait:
        try:
            r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
            r.ping()
            print(f"Connected to Redis master at {leader_ip}:{leader_port}")
            return r
        except redis.exceptions.ConnectionError as e:
            print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
            time.sleep(wait_interval)
            elapsed += wait_interval
    print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
    sys.exit(1)

def main():
    parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
    parser.add_argument('--leader-ip', help='IP address of the initial leader node')
    parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
    parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
    parser.add_argument('--add-pdfs', help='S3 glob path for work items')
    args = parser.parse_args()

    replica_number = args.replica

    base_redis_port = 6379
    base_sentinel_port = 26379

    redis_port = base_redis_port + replica_number
    sentinel_port = base_sentinel_port + replica_number

    if replica_number == 0:
        leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
        leader_port = args.leader_port
    else:
        if not args.leader_ip:
            print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
            sys.exit(1)
        leader_ip = args.leader_ip
        leader_port = args.leader_port

    temp_dir = tempfile.mkdtemp()
    redis_conf_path = os.path.join(temp_dir, 'redis.conf')
    sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')

    print("Redis config path:", redis_conf_path)

    with open(redis_conf_path, 'w') as f:
        f.write(f'port {redis_port}\n')
        f.write(f'dbfilename dump-{replica_number}.rdb\n')
        f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
        f.write(f'logfile "redis-{replica_number}.log"\n')
        f.write(f'dir {temp_dir}\n')
        if replica_number == 0:
            f.write('bind 0.0.0.0\n')
        else:
            f.write(f'replicaof {leader_ip} {leader_port}\n')

    master_name = 'mymaster'
    quorum = 1

    with open(sentinel_conf_path, 'w') as f:
        f.write(f'port {sentinel_port}\n')
        f.write(f'dir {temp_dir}\n')
        f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
        f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
        f.write(f'sentinel failover-timeout {master_name} 10000\n')
        f.write(f'sentinel parallel-syncs {master_name} 1\n')

    redis_process = subprocess.Popen(['redis-server', redis_conf_path])
    sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])

    # Register atexit function to guarantee process termination
    def terminate_processes():
        print("Terminating child processes...")
        redis_process.terminate()
        sentinel_process.terminate()
        try:
            redis_process.wait(timeout=5)
            sentinel_process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            print("Forcing termination of child processes.")
            redis_process.kill()
            sentinel_process.kill()
        print("Child processes terminated.")

    atexit.register(terminate_processes)

    # Also handle signal-based termination
    def handle_signal(signum, frame):
        print(f"Received signal {signum}. Terminating processes...")
        terminate_processes()
        sys.exit(0)

    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)

    time.sleep(2)

    # Use Sentinel to connect to the master
    from redis.sentinel import Sentinel
    sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)

    # Initial connection to Redis master
    redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)

    # Populate the work queue if it's empty, using a distributed lock
    populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)

    try:
        while True:
            try:
                # Try to get an item from the queue with a 1-minute timeout for processing
                work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
                if work_item:
                    try:
                        process(work_item)
                        # Remove from the processing queue if processed successfully
                        redis_client.lrem("processing_queue", 1, work_item)
                    except Exception as e:
                        print(f"Error processing {work_item}: {e}")
                        # If an error occurs, let it be requeued after timeout

                queue_length = redis_client.llen("work_queue")
                print(f"Total work items in queue: {queue_length}")
                time.sleep(0.1)

            except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
                print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
                # Attempt to reconnect using Sentinel
                while True:
                    try:
                        redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
                        print("Reconnected to Redis master.")
                        break  # Exit the reconnection loop and resume work
                    except redis.exceptions.ConnectionError:
                        print("Reconnection failed. Retrying in 5 seconds...")
                        time.sleep(5)
            except Exception as e:
                print(f"Unexpected error: {e}")
                handle_signal(None, None)
    except KeyboardInterrupt:
        handle_signal(None, None)

if __name__ == '__main__':
    main()
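For readers skimming the diff: the removed prototype was built around Redis's reliable-queue pattern, where BRPOPLPUSH atomically moves an item from the pending list to a processing list, and the item is only removed from the processing list after it has been handled successfully, so work held by a crashed worker stays recoverable. A minimal standalone sketch of that pattern, assuming a plain local Redis connection rather than the Sentinel setup above:

import redis

# Assumes a Redis server on localhost; the prototype obtained its client via Sentinel instead.
r = redis.Redis(host="localhost", port=6379, decode_responses=True)

def work_one_item(process):
    # Atomically move one item from the pending queue to the processing queue,
    # waiting up to 60 seconds for work to appear.
    item = r.brpoplpush("work_queue", "processing_queue", timeout=60)
    if item is None:
        return False
    try:
        process(item)
        # Only acknowledge (remove from the processing queue) after success.
        r.lrem("processing_queue", 1, item)
    except Exception:
        # On failure the item stays in processing_queue, where a separate
        # recovery pass could push it back onto work_queue.
        pass
    return True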
Added (the new batch-inference pipeline manager):

import logging
import argparse
import boto3
import os
from tqdm import tqdm
from urllib.parse import urlparse
import zstandard as zstd
from io import BytesIO, TextIOWrapper

from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes

# Basic logging setup for now
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global s3 client for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

def download_zstd_csv(s3_client, s3_path):
    """Download and decompress a .zstd CSV file from S3."""
    try:
        compressed_data = get_s3_bytes(s3_client, s3_path)
        dctx = zstd.ZstdDecompressor()
        decompressed = dctx.decompress(compressed_data)
        text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
        lines = text_stream.readlines()
        logger.info(f"Downloaded and decompressed {s3_path}")
        return lines
    except s3_client.exceptions.NoSuchKey:
        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
        return []

def upload_zstd_csv(s3_client, s3_path, lines):
    """Compress and upload a list of lines as a .zstd CSV file to S3."""
    joined_text = "\n".join(lines)
    compressor = zstd.ZstdCompressor()
    compressed = compressor.compress(joined_text.encode('utf-8'))
    put_s3_bytes(s3_client, s3_path, compressed)
    logger.info(f"Uploaded compressed {s3_path}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    # Check list of pdfs and that it matches what's in the workspace
    if args.pdfs:
        if args.pdfs.startswith("s3://"):
            logger.info(f"Expanding s3 glob at {args.pdfs}")
            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
        elif os.path.exists(args.pdfs):
            logger.info(f"Loading file at {args.pdfs}")
            with open(args.pdfs, "r") as f:
                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
        else:
            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")

        all_pdfs = set(all_pdfs)
        logger.info(f"Found {len(all_pdfs):,} total pdf paths")

        index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)

        # Parse existing work items into groups
        existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
        existing_pdf_set = set(pdf for group in existing_groups for pdf in group)

        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")

        # Remove existing PDFs from all_pdfs
        new_pdfs = all_pdfs - existing_pdf_set
        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")

        # Group the new PDFs into chunks of group_size
        new_groups = []
        current_group = []
        for pdf in sorted(new_pdfs):  # Sort for consistency
            current_group.append(pdf)
            if len(current_group) == args.group_size:
                new_groups.append(current_group)
                current_group = []
        if current_group:
            new_groups.append(current_group)

        logger.info(f"Created {len(new_groups):,} new work groups")

        # Combine existing groups with new groups
        combined_groups = existing_groups + new_groups

        # Prepare lines to write back
        combined_lines = [",".join(group) for group in combined_groups]

        # Upload the combined work items back to S3
        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)

        logger.info("Completed adding new PDFs.")

    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
    # If not, then your job is to do the actual work

    # Start up the sglang server

    # Read in the work queue from s3
    # Read in the done items from the s3 workspace

    # Spawn up to N workers to do:
        # In a loop, take a random work item, read in the pdfs, queue in their requests
            # Get results back, retry any failed pages
            # Check periodically if that work is done in s3, if so, then abandon this work
        # Save results back to s3 workspace output folder

    # Possible future addon, in beaker, discover other nodes on this same job
    # Send them a message when you take a work item off the queue
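The trailing comments in the new script outline where this approach is headed: each node starts a local sglang server, reads the work index and the set of already-completed items from the S3 workspace, and runs up to --workers workers that each take a random work group, run its PDFs through inference, retry failed pages, abandon groups another node has already finished, and write results back to the workspace output folder. The script would presumably be invoked with the workspace as its positional argument, something like python <script>.py s3://bucket/workspace --pdfs 's3://bucket/pdfs/*.pdf'; the script name and bucket here are placeholders, since the file path is not visible in this view. Below is a rough, hypothetical sketch of that planned worker loop; every helper in it is a stand-in invented for illustration, and nothing in it exists in this commit:

import random

def build_page_queries(pdf_path):
    """Placeholder: render pages and build inference requests for one PDF."""
    return [{"pdf": pdf_path, "page": i} for i in range(3)]

def send_to_sglang(queries):
    """Placeholder: send a batch of requests to the local sglang server."""
    return [{"ok": True, "query": q} for q in queries]

def is_group_done_in_s3(workspace, group):
    """Placeholder: check whether another node already finished this group."""
    return False

def save_results(workspace, group, results):
    """Placeholder: write results to the workspace output folder in S3."""
    print(f"Would save {len(results)} results for a group of {len(group)} PDFs")

def worker_loop(workspace, work_groups):
    # One worker's loop over the remaining work groups from the index file.
    while work_groups:
        group = random.choice(work_groups)        # take a random work item
        if is_group_done_in_s3(workspace, group):
            work_groups.remove(group)             # abandon work finished elsewhere
            continue
        results = []
        for pdf_path in group:
            queries = build_page_queries(pdf_path)
            responses = send_to_sglang(queries)
            failed = [q for q, r in zip(queries, responses) if not r["ok"]]
            if failed:
                responses += send_to_sglang(failed)   # retry any failed pages
            results.append((pdf_path, responses))
        save_results(workspace, group, results)
        work_groups.remove(group)

if __name__ == "__main__":
    worker_loop("s3://bucket/workspace", [["a.pdf", "b.pdf"], ["c.pdf"]])

Presumably the point of grouping --group_size PDFs per work item is to keep the index file compact and let a single done-check against S3 cover many PDFs at once, though the commit itself does not spell that out.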