Starting on a new approach

Jake Poznanski 2024-11-07 18:21:23 +00:00
parent faf8659028
commit 12a91ffa96

@@ -1,226 +1,139 @@
+import logging
 import argparse
-import subprocess
-import signal
-import sys
-import os
-import time
-import tempfile
-import redis
-import redis.exceptions
-import random
 import boto3
-import atexit
-from pdelfin.s3_utils import expand_s3_glob
+import os
+from tqdm import tqdm
+from urllib.parse import urlparse
+import zstandard as zstd
+from io import BytesIO, TextIOWrapper
+from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes
+
+# Basic logging setup for now
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logging.basicConfig(level=logging.INFO)
+
+# Quiet logs from pypdf
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+
+# Global s3 client for the whole script, feel free to adjust params if you need it
 workspace_s3 = boto3.client('s3')
 pdf_s3 = boto3.client('s3')
-
-LOCK_KEY = "queue_populating"
-LOCK_TIMEOUT = 30  # seconds
-
-def populate_queue_if_empty(queue, s3_glob_path, redis_client):
-    """
-    Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
-    Only one worker should populate the queue at a time.
-    """
-    if queue.llen("work_queue") == 0:
-        # Attempt to acquire the lock
-        lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
-        if lock_acquired:
-            print("Acquired lock to populate the queue.")
-            try:
-                paths = expand_s3_glob(pdf_s3, s3_glob_path)
-                if not paths:
-                    print("No paths found to populate the queue.")
-                    return
-                for path in paths:
-                    queue.rpush("work_queue", path)
-                print("Queue populated with initial work items.")
-            except Exception as e:
-                print(f"Error populating queue: {e}")
-                # Optionally, handle retry logic or alerting here
-            finally:
-                # Release the lock
-                redis_client.delete(LOCK_KEY)
-                print("Released lock after populating the queue.")
-        else:
-            print("Another worker is populating the queue. Waiting for it to complete.")
-            # Optionally, wait until the queue is populated
-            wait_for_queue_population(queue)
-
-def wait_for_queue_population(queue, wait_time=5, max_wait=60):
-    """
-    Wait until the queue is populated by another worker.
-    """
-    elapsed = 0
-    while elapsed < max_wait:
-        queue_length = queue.llen("work_queue")
-        if queue_length > 0:
-            print("Queue has been populated by another worker.")
-            return
-        print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
-        time.sleep(wait_time)
-        elapsed += wait_time
-    print("Timeout waiting for queue to be populated.")
-    sys.exit(1)
-
-def process(item):
-    # Simulate processing time between 1 and 3 seconds
-    print(f"Processing item: {item}")
-    time.sleep(0.5)
-    print(f"Completed processing item: {item}")
-
-def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
-    """
-    Obtain a Redis client using Sentinel, with retry logic.
-    """
-    elapsed = 0
-    wait_interval = 1  # seconds
-    while elapsed < max_wait:
-        try:
-            r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
-            r.ping()
-            print(f"Connected to Redis master at {leader_ip}:{leader_port}")
-            return r
-        except redis.exceptions.ConnectionError as e:
-            print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
-            time.sleep(wait_interval)
-            elapsed += wait_interval
-    print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
-    sys.exit(1)
-
-def main():
-    parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
-    parser.add_argument('--leader-ip', help='IP address of the initial leader node')
-    parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
-    parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
-    parser.add_argument('--add-pdfs', help='S3 glob path for work items')
+
+def download_zstd_csv(s3_client, s3_path):
+    """Download and decompress a .zstd CSV file from S3."""
+    try:
+        compressed_data = get_s3_bytes(s3_client, s3_path)
+        dctx = zstd.ZstdDecompressor()
+        decompressed = dctx.decompress(compressed_data)
+        text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
+        lines = text_stream.readlines()
+        logger.info(f"Downloaded and decompressed {s3_path}")
+        return lines
+    except s3_client.exceptions.NoSuchKey:
+        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
+        return []
+
+def upload_zstd_csv(s3_client, s3_path, lines):
+    """Compress and upload a list of lines as a .zstd CSV file to S3."""
+    joined_text = "\n".join(lines)
+    compressor = zstd.ZstdCompressor()
+    compressed = compressor.compress(joined_text.encode('utf-8'))
+    put_s3_bytes(s3_client, s3_path, compressed)
+    logger.info(f"Uploaded compressed {s3_path}")
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
+    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
+    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
+    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
+    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
+    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
+    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
     args = parser.parse_args()
-
-    replica_number = args.replica
-
-    base_redis_port = 6379
-    base_sentinel_port = 26379
-
-    redis_port = base_redis_port + replica_number
-    sentinel_port = base_sentinel_port + replica_number
-
-    if replica_number == 0:
-        leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
-        leader_port = args.leader_port
-    else:
-        if not args.leader_ip:
-            print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
-            sys.exit(1)
-        leader_ip = args.leader_ip
-        leader_port = args.leader_port
-
-    temp_dir = tempfile.mkdtemp()
-    redis_conf_path = os.path.join(temp_dir, 'redis.conf')
-    sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')
-
-    print("Redis config path:", redis_conf_path)
-
-    with open(redis_conf_path, 'w') as f:
-        f.write(f'port {redis_port}\n')
-        f.write(f'dbfilename dump-{replica_number}.rdb\n')
-        f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
-        f.write(f'logfile "redis-{replica_number}.log"\n')
-        f.write(f'dir {temp_dir}\n')
-        if replica_number == 0:
-            f.write('bind 0.0.0.0\n')
+
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.client("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.client("s3")
+
+    # Check list of pdfs and that it matches what's in the workspace
+    if args.pdfs:
+        if args.pdfs.startswith("s3://"):
+            logger.info(f"Expanding s3 glob at {args.pdfs}")
+            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
+        elif os.path.exists(args.pdfs):
+            logger.info(f"Loading file at {args.pdfs}")
+            with open(args.pdfs, "r") as f:
+                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
         else:
-            f.write(f'replicaof {leader_ip} {leader_port}\n')
-
-    master_name = 'mymaster'
-    quorum = 1
-
-    with open(sentinel_conf_path, 'w') as f:
-        f.write(f'port {sentinel_port}\n')
-        f.write(f'dir {temp_dir}\n')
-        f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
-        f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
-        f.write(f'sentinel failover-timeout {master_name} 10000\n')
-        f.write(f'sentinel parallel-syncs {master_name} 1\n')
-
-    redis_process = subprocess.Popen(['redis-server', redis_conf_path])
-    sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])
-
-    # Register atexit function to guarantee process termination
-    def terminate_processes():
-        print("Terminating child processes...")
-        redis_process.terminate()
-        sentinel_process.terminate()
-        try:
-            redis_process.wait(timeout=5)
-            sentinel_process.wait(timeout=5)
-        except subprocess.TimeoutExpired:
-            print("Forcing termination of child processes.")
-            redis_process.kill()
-            sentinel_process.kill()
-        print("Child processes terminated.")
-
-    atexit.register(terminate_processes)
-
-    # Also handle signal-based termination
-    def handle_signal(signum, frame):
-        print(f"Received signal {signum}. Terminating processes...")
-        terminate_processes()
-        sys.exit(0)
-
-    signal.signal(signal.SIGINT, handle_signal)
-    signal.signal(signal.SIGTERM, handle_signal)
-
-    time.sleep(2)
-
-    # Use Sentinel to connect to the master
-    from redis.sentinel import Sentinel
-    sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)
-
-    # Initial connection to Redis master
-    redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
-
-    # Populate the work queue if it's empty, using a distributed lock
-    populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)
-
-    try:
-        while True:
-            try:
-                # Try to get an item from the queue with a 1-minute timeout for processing
-                work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
-                if work_item:
-                    try:
-                        process(work_item)
-                        # Remove from the processing queue if processed successfully
-                        redis_client.lrem("processing_queue", 1, work_item)
-                    except Exception as e:
-                        print(f"Error processing {work_item}: {e}")
-                        # If an error occurs, let it be requeued after timeout
-
-                queue_length = redis_client.llen("work_queue")
-                print(f"Total work items in queue: {queue_length}")
-                time.sleep(0.1)
-            except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
-                print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
-                # Attempt to reconnect using Sentinel
-                while True:
-                    try:
-                        redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
-                        print("Reconnected to Redis master.")
-                        break  # Exit the reconnection loop and resume work
-                    except redis.exceptions.ConnectionError:
-                        print("Reconnection failed. Retrying in 5 seconds...")
-                        time.sleep(5)
-            except Exception as e:
-                print(f"Unexpected error: {e}")
-                handle_signal(None, None)
-    except KeyboardInterrupt:
-        handle_signal(None, None)
-
-if __name__ == '__main__':
-    main()
+            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
+
+        all_pdfs = set(all_pdfs)
+        logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+
+        index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+
+        # Parse existing work items into groups
+        existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
+        existing_pdf_set = set(pdf for group in existing_groups for pdf in group)
+
+        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
+
+        # Remove existing PDFs from all_pdfs
+        new_pdfs = all_pdfs - existing_pdf_set
+        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
+
+        # Group the new PDFs into chunks of group_size
+        new_groups = []
+        current_group = []
+        for pdf in sorted(new_pdfs):  # Sort for consistency
+            current_group.append(pdf)
+            if len(current_group) == args.group_size:
+                new_groups.append(current_group)
+                current_group = []
+        if current_group:
+            new_groups.append(current_group)
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine existing groups with new groups
+        combined_groups = existing_groups + new_groups
+
+        # Prepare lines to write back
+        combined_lines = [",".join(group) for group in combined_groups]
+
+        # Upload the combined work items back to S3
+        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
+
+        logger.info("Completed adding new PDFs.")
+
+    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
+    # If not, then your job is to do the actual work
+
+    # Start up the sglang server
+    # Read in the work queue from s3
+    # Read in the done items from the s3 workspace
+
+    # Spawn up to N workers to do:
+    # In a loop, take a random work item, read in the pdfs, queue in their requests
+    # Get results back, retry any failed pages
+    # Check periodically if that work is done in s3, if so, then abandon this work
+    # Save results back to s3 workspace output folder
+
+    # Possible future addon, in beaker, discover other nodes on this same job
+    # Send them a message when you take a work item off the queue
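
For orientation, a hypothetical invocation of the new manager script added in this commit might look like the line below. The file name manager.py and the bucket/prefix values are placeholders, not part of the commit; the positional workspace argument and the --pdfs, --group_size, and --workers flags come from the argparse setup above.

    # Placeholder file name and S3 paths; arguments taken from the argparse definitions in this commit
    python manager.py s3://my-bucket/pdf-workspace --pdfs "s3://my-bucket/raw-pdfs/*.pdf" --group_size 20 --workers 10

With these arguments the script expands the S3 glob, drops any PDFs already listed in pdf_index_list.csv.zstd in the workspace, and appends the remaining paths as comma-separated groups of 20 to that index file; the worker loop itself is still only sketched in the trailing comments.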