mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-01 18:43:45 +00:00
Starting up server and workers async now
This commit is contained in:
parent
a39350e074
commit
ee72b3601e
@ -74,17 +74,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
|
||||
sha1.update(pdf.encode('utf-8'))
|
||||
return sha1.hexdigest()
|
||||
|
||||
async def start_sglang_server(args):
|
||||
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
|
||||
download_directory(args.model, model_cache_dir)
|
||||
|
||||
# Start up the sglang server
|
||||
sglang_process = subprocess.Popen([
|
||||
"python3", "-m", "sglang.launch_server",
|
||||
"--model-path", model_cache_dir,
|
||||
"--chat-template", args.model_chat_template,
|
||||
"--context-length", str(args.model_max_context),
|
||||
])
|
||||
|
||||
async def populate_pdf_work_queue(args):
|
||||
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
|
||||
@ -152,44 +141,73 @@ async def populate_pdf_work_queue(args):
|
||||
|
||||
async def load_pdf_work_queue(args) -> asyncio.Queue:
|
||||
index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
|
||||
|
||||
# Read in the work queue from s3
|
||||
work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
|
||||
work_queue = {}
|
||||
for line in work_queue_lines:
|
||||
if line.strip():
|
||||
parts = line.strip().split(",")
|
||||
group_hash = parts[0]
|
||||
group_pdfs = parts[1:]
|
||||
work_queue[group_hash] = group_pdfs
|
||||
output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"
|
||||
|
||||
# Read in the done items from the s3 workspace
|
||||
done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
|
||||
done_work_hashes = set()
|
||||
for item in done_work_items:
|
||||
filename = os.path.basename(item)
|
||||
if filename.startswith('output_') and filename.endswith('.jsonl'):
|
||||
group_hash = filename[len('output_'):-len('.jsonl')]
|
||||
done_work_hashes.add(group_hash)
|
||||
# Define the two blocking I/O operations
|
||||
download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
|
||||
expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
|
||||
|
||||
remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
|
||||
remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
|
||||
# Run both tasks concurrently
|
||||
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
|
||||
|
||||
# Process the work queue lines
|
||||
work_queue = {
|
||||
parts[0]: parts[1:]
|
||||
for line in work_queue_lines
|
||||
if (parts := line.strip().split(",")) and line.strip()
|
||||
}
|
||||
|
||||
# Extract done work hashes
|
||||
done_work_hashes = {
|
||||
os.path.basename(item)[len('output_'):-len('.jsonl')]
|
||||
for item in done_work_items
|
||||
if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
|
||||
}
|
||||
|
||||
# Determine remaining work
|
||||
remaining_work_hashes = set(work_queue) - done_work_hashes
|
||||
remaining_work_queue = {
|
||||
hash_: work_queue[hash_]
|
||||
for hash_ in remaining_work_hashes
|
||||
}
|
||||
|
||||
# Populate the asyncio.Queue with remaining work
|
||||
queue = asyncio.Queue()
|
||||
|
||||
for work in remaining_work_queue:
|
||||
await queue.put((work, remaining_work_queue[work]))
|
||||
for work, pdfs in remaining_work_queue.items():
|
||||
await queue.put((work, pdfs))
|
||||
|
||||
return queue
|
||||
|
||||
async def process_pdf(args, pdf_s3_path):
|
||||
await asyncio.sleep(1)
|
||||
return f"pdf: {pdf_s3_path}"
|
||||
|
||||
async def worker(args, queue):
|
||||
while True:
|
||||
work = await queue.get()
|
||||
[work_hash, pdfs] = await queue.get()
|
||||
|
||||
completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
|
||||
logger.info(f"Completed {completed_pdfs}")
|
||||
|
||||
logger.info(f"Got work to do for {work}")
|
||||
queue.task_done()
|
||||
|
||||
|
||||
async def sglang_server_task(args):
|
||||
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
|
||||
#download_directory(args.model, model_cache_dir)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"python3",
|
||||
|
||||
"-m", "sglang.launch_server",
|
||||
"--model-path", model_cache_dir,
|
||||
"--chat-template", args.model_chat_template,
|
||||
"--context-length", str(args.model_max_context),
|
||||
)
|
||||
|
||||
await proc.wait()
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
|
||||
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
|
||||
@ -222,25 +240,31 @@ async def main():
|
||||
if args.pdfs:
|
||||
await populate_pdf_work_queue(args)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_task(args))
|
||||
|
||||
work_queue = await load_pdf_work_queue(args)
|
||||
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
|
||||
|
||||
|
||||
# Create worker tasks to process the queue concurrently.
|
||||
tasks = []
|
||||
worker_tasks = []
|
||||
for i in range(args.workers):
|
||||
task = asyncio.create_task(worker(args, work_queue))
|
||||
tasks.append(task)
|
||||
worker_tasks.append(task)
|
||||
|
||||
# Wait for the queue to be fully processed
|
||||
await work_queue.join()
|
||||
|
||||
# Cancel our worker tasks.
|
||||
for task in tasks:
|
||||
for task in worker_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait until all worker tasks are cancelled.
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await asyncio.gather(*worker_tasks, return_exceptions=True)
|
||||
|
||||
# Wait for server to stop
|
||||
sglang_server.cancel()
|
||||
await sglang_server
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def parse_s3_path(s3_path: str) -> tuple[str, str]:
|
||||
if not s3_path.startswith('s3://'):
|
||||
raise ValueError('s3_path must start with s3://')
|
||||
if not (s3_path.startswith('s3://') or s3_path.startswith('gs://') or s3_path.startswith('weka://')):
|
||||
raise ValueError('s3_path must start with s3://, gs://, or weka://')
|
||||
parsed = urlparse(s3_path)
|
||||
bucket = parsed.netloc
|
||||
key = parsed.path.lstrip('/')
|
||||
@ -137,10 +137,10 @@ def download_directory(model_choices: list[str], local_dir: str):
|
||||
"""
|
||||
Download the model to a specified local directory.
|
||||
The function will attempt to download from the first available source in the provided list.
|
||||
Supports Google Cloud Storage (gs://) and Amazon S3 (s3://) links.
|
||||
Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
|
||||
|
||||
Args:
|
||||
model_choices (list[str]): List of model paths (gs:// or s3://).
|
||||
model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
|
||||
local_dir (str): Local directory path where the model will be downloaded.
|
||||
|
||||
Raises:
|
||||
@ -151,11 +151,20 @@ def download_directory(model_choices: list[str], local_dir: str):
|
||||
local_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Local directory set to: {local_path}")
|
||||
|
||||
# Reorder model_choices to prioritize weka:// links
|
||||
weka_choices = [path for path in model_choices if path.startswith("weka://")]
|
||||
other_choices = [path for path in model_choices if not path.startswith("weka://")]
|
||||
prioritized_choices = weka_choices + other_choices
|
||||
|
||||
# Iterate through the provided choices and attempt to download from the first available source
|
||||
for model_path in model_choices:
|
||||
for model_path in prioritized_choices:
|
||||
logger.info(f"Attempting to download from: {model_path}")
|
||||
try:
|
||||
if model_path.startswith("gs://"):
|
||||
if model_path.startswith("weka://"):
|
||||
download_dir_from_weka(model_path, str(local_path))
|
||||
logger.info(f"Successfully downloaded model from Weka: {model_path}")
|
||||
return
|
||||
elif model_path.startswith("gs://"):
|
||||
download_dir_from_gcs(model_path, str(local_path))
|
||||
logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
|
||||
return
|
||||
@ -229,3 +238,57 @@ def download_dir_from_s3(s3_path: str, local_dir: str):
|
||||
pass
|
||||
|
||||
logger.info(f"Downloaded model from S3 to {local_dir}")
|
||||
|
||||
|
||||
def download_dir_from_weka(weka_path: str, local_dir: str):
|
||||
"""Download model files from Weka to a local directory."""
|
||||
# Retrieve Weka credentials from environment variables
|
||||
weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
|
||||
weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
|
||||
if not weka_access_key or not weka_secret_key:
|
||||
raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
|
||||
|
||||
# Configure the boto3 client for Weka
|
||||
weka_endpoint = "https://weka-aus.beaker.org:9000"
|
||||
boto3_config = Config(
|
||||
max_pool_connections=50, # Adjust this number based on your requirements
|
||||
signature_version='s3v4',
|
||||
retries={'max_attempts': 10, 'mode': 'standard'}
|
||||
)
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
endpoint_url=weka_endpoint,
|
||||
aws_access_key_id=weka_access_key,
|
||||
aws_secret_access_key=weka_secret_key,
|
||||
config=boto3_config
|
||||
)
|
||||
|
||||
bucket, prefix = parse_s3_path(weka_path)
|
||||
paginator = s3_client.get_paginator("list_objects_v2")
|
||||
try:
|
||||
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
|
||||
except s3_client.exceptions.NoSuchBucket:
|
||||
raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
|
||||
|
||||
objects = []
|
||||
for page in pages:
|
||||
if 'Contents' in page:
|
||||
objects.extend(page['Contents'])
|
||||
|
||||
total_files = len(objects)
|
||||
logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = []
|
||||
for obj in objects:
|
||||
key = obj["Key"]
|
||||
relative_path = os.path.relpath(key, prefix)
|
||||
local_file_path = os.path.join(local_dir, relative_path)
|
||||
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
|
||||
futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
|
||||
|
||||
# Use tqdm to display progress
|
||||
for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
|
||||
pass
|
||||
|
||||
logger.info(f"Downloaded model from Weka to {local_dir}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user