Starting up server and workers async now

Jake Poznanski 2024-11-08 09:14:00 -08:00
parent a39350e074
commit ee72b3601e
2 changed files with 133 additions and 46 deletions


@@ -74,17 +74,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))
     return sha1.hexdigest()
-async def start_sglang_server(args):
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-        "python3", "-m", "sglang.launch_server",
-        "--model-path", model_cache_dir,
-        "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
-    ])
 async def populate_pdf_work_queue(args):
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
@@ -152,44 +141,73 @@ async def populate_pdf_work_queue(args):
 async def load_pdf_work_queue(args) -> asyncio.Queue:
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
+    output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"
-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
+    # Define the two blocking I/O operations
+    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
+    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+    # Run both tasks concurrently
+    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
+    # Process the work queue lines
+    work_queue = {
+        parts[0]: parts[1:]
+        for line in work_queue_lines
+        if (parts := line.strip().split(",")) and line.strip()
+    }
+    # Extract done work hashes
+    done_work_hashes = {
+        os.path.basename(item)[len('output_'):-len('.jsonl')]
+        for item in done_work_items
+        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
+    }
+    # Determine remaining work
+    remaining_work_hashes = set(work_queue) - done_work_hashes
+    remaining_work_queue = {
+        hash_: work_queue[hash_]
+        for hash_ in remaining_work_hashes
+    }
     # Populate the asyncio.Queue with remaining work
     queue = asyncio.Queue()
-    for work in remaining_work_queue:
-        await queue.put((work, remaining_work_queue[work]))
+    for work, pdfs in remaining_work_queue.items():
+        await queue.put((work, pdfs))
     return queue
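The rewrite above replaces two sequential blocking S3 calls with asyncio.to_thread wrappers awaited together. A minimal, self-contained sketch of that pattern, using stand-in functions in place of download_zstd_csv and expand_s3_glob (the stand-ins are illustrative, not from the repo):

    import asyncio
    import time

    def fetch_index() -> list[str]:
        time.sleep(1)  # stand-in for the blocking download_zstd_csv call
        return ["hash1,s3://bucket/a.pdf", "hash2,s3://bucket/b.pdf"]

    def list_outputs() -> list[str]:
        time.sleep(1)  # stand-in for the blocking expand_s3_glob call
        return ["s3://bucket/dolma_documents/output_hash1.jsonl"]

    async def demo() -> None:
        # Each blocking call runs in its own worker thread; gather() awaits both,
        # so the pair completes in roughly 1 second instead of 2.
        index_lines, done_items = await asyncio.gather(
            asyncio.to_thread(fetch_index),
            asyncio.to_thread(list_outputs),
        )
        print(len(index_lines), "index lines,", len(done_items), "done items")

    asyncio.run(demo())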
+async def process_pdf(args, pdf_s3_path):
+    await asyncio.sleep(1)
+    return f"pdf: {pdf_s3_path}"
 async def worker(args, queue):
     while True:
-        work = await queue.get()
+        [work_hash, pdfs] = await queue.get()
+        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        logger.info(f"Completed {completed_pdfs}")
-        logger.info(f"Got work to do for {work}")
         queue.task_done()
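worker now pulls a (work_hash, pdfs) pair off the queue and fans the PDFs out with asyncio.gather, which schedules one process_pdf coroutine per PDF and returns the results in input order. A small sketch of that fan-out on its own (hypothetical paths, stub processing):

    import asyncio

    async def process_pdf(pdf_s3_path: str) -> str:
        await asyncio.sleep(0.1)  # placeholder work, mirroring the stub above
        return f"pdf: {pdf_s3_path}"

    async def handle_work_item(pdfs: list[str]) -> list[str]:
        # One coroutine per PDF, all awaited together; results keep input order.
        return await asyncio.gather(*[process_pdf(p) for p in pdfs])

    print(asyncio.run(handle_work_item(["s3://bucket/a.pdf", "s3://bucket/b.pdf"])))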
+async def sglang_server_task(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    #download_directory(args.model, model_cache_dir)
+    proc = await asyncio.create_subprocess_exec(
+        "python3",
+        "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    )
+    await proc.wait()
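sglang_server_task launches the server with asyncio.create_subprocess_exec, so the event loop stays free while the child process runs. A self-contained sketch of that pattern; the terminate-on-cancel cleanup here is an illustrative addition, not something this commit does:

    import asyncio

    async def run_child() -> None:
        # Launch a child process without blocking the event loop.
        proc = await asyncio.create_subprocess_exec(
            "python3", "-c", "import time; time.sleep(60)",
        )
        try:
            await proc.wait()
        except asyncio.CancelledError:
            proc.terminate()   # illustrative cleanup when the task is cancelled
            await proc.wait()
            raise

    async def demo() -> None:
        server = asyncio.create_task(run_child())
        await asyncio.sleep(1)
        server.cancel()
        await asyncio.gather(server, return_exceptions=True)

    asyncio.run(demo())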
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
@@ -222,25 +240,31 @@ async def main():
     if args.pdfs:
         await populate_pdf_work_queue(args)
+    sglang_server = asyncio.create_task(sglang_server_task(args))
     work_queue = await load_pdf_work_queue(args)
     logger.info(f"Work queue prepared with {work_queue.qsize()} items")
     # Create worker tasks to process the queue concurrently.
-    tasks = []
+    worker_tasks = []
     for i in range(args.workers):
         task = asyncio.create_task(worker(args, work_queue))
-        tasks.append(task)
+        worker_tasks.append(task)
     # Wait for the queue to be fully processed
     await work_queue.join()
     # Cancel our worker tasks.
-    for task in tasks:
+    for task in worker_tasks:
         task.cancel()
     # Wait until all worker tasks are cancelled.
-    await asyncio.gather(*tasks, return_exceptions=True)
+    await asyncio.gather(*worker_tasks, return_exceptions=True)
+    # Wait for server to stop
+    sglang_server.cancel()
+    await sglang_server
 if __name__ == "__main__":
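The main() changes above follow the standard asyncio queue lifecycle: start N worker tasks, await queue.join() until every item has been marked task_done(), then cancel the now-idle workers and gather them with return_exceptions=True. A compact, runnable sketch of that lifecycle with dummy work items:

    import asyncio

    async def worker(queue: asyncio.Queue) -> None:
        while True:
            item = await queue.get()
            await asyncio.sleep(0.05)   # stand-in for real work on `item`
            queue.task_done()

    async def demo() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        for i in range(10):
            await queue.put(i)

        workers = [asyncio.create_task(worker(queue)) for _ in range(3)]

        await queue.join()              # returns once every put() has a matching task_done()
        for w in workers:
            w.cancel()                  # workers are parked in queue.get(), safe to cancel
        await asyncio.gather(*workers, return_exceptions=True)

    asyncio.run(demo())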


@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO)
 def parse_s3_path(s3_path: str) -> tuple[str, str]:
-    if not s3_path.startswith('s3://'):
-        raise ValueError('s3_path must start with s3://')
+    if not (s3_path.startswith('s3://') or s3_path.startswith('gs://') or s3_path.startswith('weka://')):
+        raise ValueError('s3_path must start with s3://, gs://, or weka://')
     parsed = urlparse(s3_path)
     bucket = parsed.netloc
     key = parsed.path.lstrip('/')
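Widening the prefix check is all parse_s3_path needs, because urlparse splits scheme://netloc/path the same way for any scheme. A quick illustration (the weka:// path below is hypothetical):

    from urllib.parse import urlparse

    parsed = urlparse("weka://oe-training/jakep/model")   # hypothetical path
    print(parsed.netloc)              # "oe-training"  -> treated as the bucket
    print(parsed.path.lstrip("/"))    # "jakep/model"  -> treated as the key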
@@ -137,10 +137,10 @@ def download_directory(model_choices: list[str], local_dir: str):
     """
     Download the model to a specified local directory.
     The function will attempt to download from the first available source in the provided list.
-    Supports Google Cloud Storage (gs://) and Amazon S3 (s3://) links.
+    Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
     Args:
-        model_choices (list[str]): List of model paths (gs:// or s3://).
+        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
         local_dir (str): Local directory path where the model will be downloaded.
     Raises:
@@ -151,11 +151,20 @@ def download_directory(model_choices: list[str], local_dir: str):
     local_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"Local directory set to: {local_path}")
+    # Reorder model_choices to prioritize weka:// links
+    weka_choices = [path for path in model_choices if path.startswith("weka://")]
+    other_choices = [path for path in model_choices if not path.startswith("weka://")]
+    prioritized_choices = weka_choices + other_choices
     # Iterate through the provided choices and attempt to download from the first available source
-    for model_path in model_choices:
+    for model_path in prioritized_choices:
         logger.info(f"Attempting to download from: {model_path}")
         try:
-            if model_path.startswith("gs://"):
+            if model_path.startswith("weka://"):
+                download_dir_from_weka(model_path, str(local_path))
+                logger.info(f"Successfully downloaded model from Weka: {model_path}")
+                return
+            elif model_path.startswith("gs://"):
                 download_dir_from_gcs(model_path, str(local_path))
                 logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                 return
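The reordering above is a stable partition: weka:// entries move to the front while everything else keeps its original relative order. For example (hypothetical paths):

    model_choices = ["s3://bucket/model", "weka://cluster/model", "gs://bucket/model"]
    weka_choices = [p for p in model_choices if p.startswith("weka://")]
    other_choices = [p for p in model_choices if not p.startswith("weka://")]
    print(weka_choices + other_choices)
    # ['weka://cluster/model', 's3://bucket/model', 'gs://bucket/model']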
@@ -229,3 +238,57 @@ def download_dir_from_s3(s3_path: str, local_dir: str):
             pass
     logger.info(f"Downloaded model from S3 to {local_dir}")
+def download_dir_from_weka(weka_path: str, local_dir: str):
+    """Download model files from Weka to a local directory."""
+    # Retrieve Weka credentials from environment variables
+    weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
+    weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
+    if not weka_access_key or not weka_secret_key:
+        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
+    # Configure the boto3 client for Weka
+    weka_endpoint = "https://weka-aus.beaker.org:9000"
+    boto3_config = Config(
+        max_pool_connections=50,  # Adjust this number based on your requirements
+        signature_version='s3v4',
+        retries={'max_attempts': 10, 'mode': 'standard'}
+    )
+    s3_client = boto3.client(
+        's3',
+        endpoint_url=weka_endpoint,
+        aws_access_key_id=weka_access_key,
+        aws_secret_access_key=weka_secret_key,
+        config=boto3_config
+    )
+    bucket, prefix = parse_s3_path(weka_path)
+    paginator = s3_client.get_paginator("list_objects_v2")
+    try:
+        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    except s3_client.exceptions.NoSuchBucket:
+        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
+    objects = []
+    for page in pages:
+        if 'Contents' in page:
+            objects.extend(page['Contents'])
+    total_files = len(objects)
+    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = []
+        for obj in objects:
+            key = obj["Key"]
+            relative_path = os.path.relpath(key, prefix)
+            local_file_path = os.path.join(local_dir, relative_path)
+            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+        # Use tqdm to display progress
+        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
+            pass
+    logger.info(f"Downloaded model from Weka to {local_dir}")