mirror of https://github.com/allenai/olmocr.git, synced 2025-09-02 21:34:31 +00:00
Starting on a new approach
parent faf8659028
commit 12a91ffa96
@@ -1,226 +1,139 @@
-import argparse
-import subprocess
-import signal
-import sys
-import os
-import time
-import tempfile
-import redis
-import redis.exceptions
-import random
-import boto3
-import atexit
-
-from pdelfin.s3_utils import expand_s3_glob
-
-workspace_s3 = boto3.client('s3')
-pdf_s3 = boto3.client('s3')
-
-LOCK_KEY = "queue_populating"
-LOCK_TIMEOUT = 30  # seconds
-
-def populate_queue_if_empty(queue, s3_glob_path, redis_client):
-    """
-    Check if the queue is empty. If it is, attempt to acquire a lock to populate it.
-    Only one worker should populate the queue at a time.
-    """
-    if queue.llen("work_queue") == 0:
-        # Attempt to acquire the lock
-        lock_acquired = redis_client.set(LOCK_KEY, "locked", nx=True, ex=LOCK_TIMEOUT)
-        if lock_acquired:
-            print("Acquired lock to populate the queue.")
-            try:
-                paths = expand_s3_glob(pdf_s3, s3_glob_path)
-                if not paths:
-                    print("No paths found to populate the queue.")
-                    return
-                for path in paths:
-                    queue.rpush("work_queue", path)
-                print("Queue populated with initial work items.")
-            except Exception as e:
-                print(f"Error populating queue: {e}")
-                # Optionally, handle retry logic or alerting here
-            finally:
-                # Release the lock
-                redis_client.delete(LOCK_KEY)
-                print("Released lock after populating the queue.")
-        else:
-            print("Another worker is populating the queue. Waiting for it to complete.")
-            # Optionally, wait until the queue is populated
-            wait_for_queue_population(queue)
-
-def wait_for_queue_population(queue, wait_time=5, max_wait=60):
-    """
-    Wait until the queue is populated by another worker.
-    """
-    elapsed = 0
-    while elapsed < max_wait:
-        queue_length = queue.llen("work_queue")
-        if queue_length > 0:
-            print("Queue has been populated by another worker.")
-            return
-        print(f"Waiting for queue to be populated... ({elapsed + wait_time}/{max_wait} seconds)")
-        time.sleep(wait_time)
-        elapsed += wait_time
-    print("Timeout waiting for queue to be populated.")
-    sys.exit(1)
-
-def process(item):
-    # Simulate processing time between 1 and 3 seconds
-    print(f"Processing item: {item}")
-    time.sleep(0.5)
-    print(f"Completed processing item: {item}")
-
-def get_redis_client(sentinel, master_name, leader_ip, leader_port, max_wait=60):
-    """
-    Obtain a Redis client using Sentinel, with retry logic.
-    """
-    elapsed = 0
-    wait_interval = 1  # seconds
-    while elapsed < max_wait:
-        try:
-            r = sentinel.master_for(master_name, socket_timeout=0.1, decode_responses=True)
-            r.ping()
-            print(f"Connected to Redis master at {leader_ip}:{leader_port}")
-            return r
-        except redis.exceptions.ConnectionError as e:
-            print(f"Attempt {elapsed + 1}: Unable to connect to Redis master at {leader_ip}:{leader_port}. Retrying in {wait_interval} second(s)...")
-            time.sleep(wait_interval)
-            elapsed += wait_interval
-    print(f"Failed to connect to Redis master at {leader_ip}:{leader_port} after {max_wait} seconds. Exiting.")
-    sys.exit(1)
-
-def main():
-    parser = argparse.ArgumentParser(description='Set up Redis Sentinel-based worker queue.')
-    parser.add_argument('--leader-ip', help='IP address of the initial leader node')
-    parser.add_argument('--leader-port', type=int, default=6379, help='Port of the initial leader node')
-    parser.add_argument('--replica', type=int, required=True, help='Replica number (0 to N-1)')
-    parser.add_argument('--add-pdfs', help='S3 glob path for work items')
-
-    args = parser.parse_args()
-
-    replica_number = args.replica
-
-    base_redis_port = 6379
-    base_sentinel_port = 26379
-
-    redis_port = base_redis_port + replica_number
-    sentinel_port = base_sentinel_port + replica_number
-
-    if replica_number == 0:
-        leader_ip = args.leader_ip if args.leader_ip else '127.0.0.1'
-        leader_port = args.leader_port
-    else:
-        if not args.leader_ip:
-            print('Error: --leader-ip is required for replica nodes (replica_number >= 1)')
-            sys.exit(1)
-        leader_ip = args.leader_ip
-        leader_port = args.leader_port
-
-    temp_dir = tempfile.mkdtemp()
-    redis_conf_path = os.path.join(temp_dir, 'redis.conf')
-    sentinel_conf_path = os.path.join(temp_dir, 'sentinel.conf')
-
-    print("Redis config path:", redis_conf_path)
-
-    with open(redis_conf_path, 'w') as f:
-        f.write(f'port {redis_port}\n')
-        f.write(f'dbfilename dump-{replica_number}.rdb\n')
-        f.write(f'appendfilename "appendonly-{replica_number}.aof"\n')
-        f.write(f'logfile "redis-{replica_number}.log"\n')
-        f.write(f'dir {temp_dir}\n')
-        if replica_number == 0:
-            f.write('bind 0.0.0.0\n')
-        else:
-            f.write(f'replicaof {leader_ip} {leader_port}\n')
-
-    master_name = 'mymaster'
-    quorum = 1
-
-    with open(sentinel_conf_path, 'w') as f:
-        f.write(f'port {sentinel_port}\n')
-        f.write(f'dir {temp_dir}\n')
-        f.write(f'sentinel monitor {master_name} {leader_ip} {leader_port} {quorum}\n')
-        f.write(f'sentinel down-after-milliseconds {master_name} 5000\n')
-        f.write(f'sentinel failover-timeout {master_name} 10000\n')
-        f.write(f'sentinel parallel-syncs {master_name} 1\n')
-
-    redis_process = subprocess.Popen(['redis-server', redis_conf_path])
-    sentinel_process = subprocess.Popen(['redis-sentinel', sentinel_conf_path])
-
-    # Register atexit function to guarantee process termination
-    def terminate_processes():
-        print("Terminating child processes...")
-        redis_process.terminate()
-        sentinel_process.terminate()
-        try:
-            redis_process.wait(timeout=5)
-            sentinel_process.wait(timeout=5)
-        except subprocess.TimeoutExpired:
-            print("Forcing termination of child processes.")
-            redis_process.kill()
-            sentinel_process.kill()
-        print("Child processes terminated.")
-
-    atexit.register(terminate_processes)
-
-    # Also handle signal-based termination
-    def handle_signal(signum, frame):
-        print(f"Received signal {signum}. Terminating processes...")
-        terminate_processes()
-        sys.exit(0)
-
-    signal.signal(signal.SIGINT, handle_signal)
-    signal.signal(signal.SIGTERM, handle_signal)
-
-    time.sleep(2)
-
-    # Use Sentinel to connect to the master
-    from redis.sentinel import Sentinel
-    sentinel = Sentinel([('127.0.0.1', sentinel_port)], socket_timeout=0.1)
-
-    # Initial connection to Redis master
-    redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
-
-    # Populate the work queue if it's empty, using a distributed lock
-    populate_queue_if_empty(redis_client, args.add_pdfs, redis_client)
-
-    try:
-        while True:
-            try:
-                # Try to get an item from the queue with a 1-minute timeout for processing
-                work_item = redis_client.brpoplpush("work_queue", "processing_queue", 60)
-                if work_item:
-                    try:
-                        process(work_item)
-                        # Remove from the processing queue if processed successfully
-                        redis_client.lrem("processing_queue", 1, work_item)
-                    except Exception as e:
-                        print(f"Error processing {work_item}: {e}")
-                        # If an error occurs, let it be requeued after timeout
-
-                queue_length = redis_client.llen("work_queue")
-                print(f"Total work items in queue: {queue_length}")
-
-                time.sleep(0.1)
-
-            except (redis.exceptions.ConnectionError, redis.exceptions.TimeoutError) as e:
-                print("Lost connection to Redis. Attempting to reconnect using Sentinel...")
-                # Attempt to reconnect using Sentinel
-                while True:
-                    try:
-                        redis_client = get_redis_client(sentinel, master_name, leader_ip, leader_port)
-                        print("Reconnected to Redis master.")
-                        break  # Exit the reconnection loop and resume work
-                    except redis.exceptions.ConnectionError:
-                        print("Reconnection failed. Retrying in 5 seconds...")
-                        time.sleep(5)
-            except Exception as e:
-                print(f"Unexpected error: {e}")
-                handle_signal(None, None)
-
-    except KeyboardInterrupt:
-        handle_signal(None, None)
-
-if __name__ == '__main__':
-    main()
+import logging
+import argparse
+import boto3
+import os
+
+from tqdm import tqdm
+from urllib.parse import urlparse
+import zstandard as zstd
+from io import BytesIO, TextIOWrapper
+
+from pdelfin.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path, put_s3_bytes
+
+# Basic logging setup for now
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logging.basicConfig(level=logging.INFO)
+
+# Quiet logs from pypdf
+logging.getLogger("pypdf").setLevel(logging.ERROR)
+
+# Global s3 client for the whole script, feel free to adjust params if you need it
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
+
+
+def download_zstd_csv(s3_client, s3_path):
+    """Download and decompress a .zstd CSV file from S3."""
+    try:
+        compressed_data = get_s3_bytes(s3_client, s3_path)
+        dctx = zstd.ZstdDecompressor()
+        decompressed = dctx.decompress(compressed_data)
+        text_stream = TextIOWrapper(BytesIO(decompressed), encoding='utf-8')
+        lines = text_stream.readlines()
+        logger.info(f"Downloaded and decompressed {s3_path}")
+        return lines
+    except s3_client.exceptions.NoSuchKey:
+        logger.info(f"No existing {s3_path} found in s3, starting fresh.")
+        return []
+
+
+def upload_zstd_csv(s3_client, s3_path, lines):
+    """Compress and upload a list of lines as a .zstd CSV file to S3."""
+    joined_text = "\n".join(lines)
+    compressor = zstd.ZstdCompressor()
+    compressed = compressor.compress(joined_text.encode('utf-8'))
+    put_s3_bytes(s3_client, s3_path, compressed)
+    logger.info(f"Uploaded compressed {s3_path}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
+    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
+    parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
+    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
+    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
+    parser.add_argument('--group_size', type=int, default=20, help='Number of pdfs that will be part of each work item in the work queue.')
+    parser.add_argument('--workers', type=int, default=10, help='Number of workers to run at a time')
+    args = parser.parse_args()
+
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.client("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.client("s3")
+
+    # Check list of pdfs and that it matches what's in the workspace
+    if args.pdfs:
+        if args.pdfs.startswith("s3://"):
+            logger.info(f"Expanding s3 glob at {args.pdfs}")
+            all_pdfs = expand_s3_glob(pdf_s3, args.pdfs)
+        elif os.path.exists(args.pdfs):
+            logger.info(f"Loading file at {args.pdfs}")
+            with open(args.pdfs, "r") as f:
+                all_pdfs = list(filter(None, (line.strip() for line in tqdm(f, desc="Processing PDFs"))))
+        else:
+            raise ValueError("pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
+
+        all_pdfs = set(all_pdfs)
+        logger.info(f"Found {len(all_pdfs):,} total pdf paths")
+
+        index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
+        existing_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
+
+        # Parse existing work items into groups
+        existing_groups = [line.strip().split(",") for line in existing_lines if line.strip()]
+        existing_pdf_set = set(pdf for group in existing_groups for pdf in group)
+
+        logger.info(f"Loaded {len(existing_pdf_set):,} existing pdf paths from the workspace")
+
+        # Remove existing PDFs from all_pdfs
+        new_pdfs = all_pdfs - existing_pdf_set
+        logger.info(f"{len(new_pdfs):,} new pdf paths to add to the workspace")
+
+        # Group the new PDFs into chunks of group_size
+        new_groups = []
+        current_group = []
+        for pdf in sorted(new_pdfs):  # Sort for consistency
+            current_group.append(pdf)
+            if len(current_group) == args.group_size:
+                new_groups.append(current_group)
+                current_group = []
+        if current_group:
+            new_groups.append(current_group)
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine existing groups with new groups
+        combined_groups = existing_groups + new_groups
+
+        # Prepare lines to write back
+        combined_lines = [",".join(group) for group in combined_groups]
+
+        # Upload the combined work items back to S3
+        upload_zstd_csv(workspace_s3, index_file_s3_path, combined_lines)
+
+        logger.info("Completed adding new PDFs.")
+
+    # If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
+    # If not, then your job is to do the actual work
+
+    # Start up the sglang server
+
+    # Read in the work queue from s3
+    # Read in the done items from the s3 workspace
+
+    # Spawn up to N workers to do:
+    #     In a loop, take a random work item, read in the pdfs, queue in their requests
+    #     Get results back, retry any failed pages
+    #     Check periodically if that work is done in s3, if so, then abandon this work
+    #     Save results back to s3 workspace output folder
+
+    # Possible future addon, in beaker, discover other nodes on this same job
+    # Send them a message when you take a work item off the queue
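
The trailing comments in the new file only describe the planned worker loop; nothing in this commit implements it. The sketch below is a minimal illustration of that plan, and every helper in it (load_work_groups, is_work_done, process_pdf_group, save_output) is a hypothetical stub that does not exist in the repository:

import random

# Hypothetical stubs standing in for code this commit has not written yet.
def load_work_groups(workspace):                 # e.g., parse pdf_index_list.csv.zstd
    return [["s3://bucket/a.pdf", "s3://bucket/b.pdf"]]

def is_work_done(workspace, group):              # check the workspace output folder
    return False

def process_pdf_group(group, args):              # read pdfs, queue requests, retry failed pages
    return {pdf: "placeholder result" for pdf in group}

def save_output(workspace, group, results):      # write results to the workspace output folder
    pass

def worker_loop(workspace, args):
    """One worker: repeatedly take a random work item until none remain."""
    work_groups = load_work_groups(workspace)
    while work_groups:
        group = random.choice(work_groups)       # take a random work item
        if not is_work_done(workspace, group):   # if already finished in s3, abandon this work
            results = process_pdf_group(group, args)
            save_output(workspace, group, results)
        work_groups.remove(group)

Per the --workers argument, up to that many of these loops would presumably run concurrently.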
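
As a sanity check on the work-index format, the grouping step from the new script can be exercised in isolation. This toy example (no S3 involved; the paths are made up) prints the comma-separated lines that would be appended to pdf_index_list.csv.zstd:

group_size = 3
new_pdfs = {f"s3://example-bucket/pdfs/doc{i}.pdf" for i in range(7)}  # made-up paths

new_groups = []
current_group = []
for pdf in sorted(new_pdfs):                 # sorted for consistency, as in the script
    current_group.append(pdf)
    if len(current_group) == group_size:
        new_groups.append(current_group)
        current_group = []
if current_group:
    new_groups.append(current_group)

for line in (",".join(group) for group in new_groups):
    print(line)
# Prints three work items: two groups of three paths and one group holding the leftover path.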