From ee72b3601e2f7b62867861fc8b2d580edb9d5890 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Fri, 8 Nov 2024 09:14:00 -0800
Subject: [PATCH] Starting up server and workers async now

---
 pdelfin/beakerpipeline.py | 104 +++++++++++++++++++++++---------------
 pdelfin/s3_utils.py       |  75 ++++++++++++++++++++++++---
 2 files changed, 133 insertions(+), 46 deletions(-)

diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 5673c36..b81749c 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -74,17 +74,6 @@ def compute_workgroup_sha1(work_group: list[str]) -> str:
         sha1.update(pdf.encode('utf-8'))
     return sha1.hexdigest()
 
-async def start_sglang_server(args):
-    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
-    download_directory(args.model, model_cache_dir)
-
-    # Start up the sglang server
-    sglang_process = subprocess.Popen([
-        "python3", "-m", "sglang.launch_server",
-        "--model-path", model_cache_dir,
-        "--chat-template", args.model_chat_template,
-        "--context-length", str(args.model_max_context),
-    ])
 
 async def populate_pdf_work_queue(args):
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
@@ -152,44 +141,73 @@ async def populate_pdf_work_queue(args):
 
 async def load_pdf_work_queue(args) -> asyncio.Queue:
     index_file_s3_path = os.path.join(args.workspace, "pdf_index_list.csv.zstd")
-
-    # Read in the work queue from s3
-    work_queue_lines = download_zstd_csv(workspace_s3, index_file_s3_path)
-    work_queue = {}
-    for line in work_queue_lines:
-        if line.strip():
-            parts = line.strip().split(",")
-            group_hash = parts[0]
-            group_pdfs = parts[1:]
-            work_queue[group_hash] = group_pdfs
+    output_glob = f"{args.workspace}/dolma_documents/output_*.jsonl"
 
-    # Read in the done items from the s3 workspace
-    done_work_items = expand_s3_glob(workspace_s3, f"{args.workspace}/dolma_documents/output_*.jsonl")
-    done_work_hashes = set()
-    for item in done_work_items:
-        filename = os.path.basename(item)
-        if filename.startswith('output_') and filename.endswith('.jsonl'):
-            group_hash = filename[len('output_'):-len('.jsonl')]
-            done_work_hashes.add(group_hash)
+    # Define the two blocking I/O operations
+    download_task = asyncio.to_thread(download_zstd_csv, workspace_s3, index_file_s3_path)
+    expand_task = asyncio.to_thread(expand_s3_glob, workspace_s3, output_glob)
 
-    remaining_work_hashes = set(work_queue.keys()) - done_work_hashes
-    remaining_work_queue = {hash: work_queue[hash] for hash in remaining_work_hashes}
+    # Run both tasks concurrently
+    work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
 
+    # Process the work queue lines
+    work_queue = {
+        parts[0]: parts[1:]
+        for line in work_queue_lines
+        if (parts := line.strip().split(",")) and line.strip()
+    }
+
+    # Extract done work hashes
+    done_work_hashes = {
+        os.path.basename(item)[len('output_'):-len('.jsonl')]
+        for item in done_work_items
+        if os.path.basename(item).startswith('output_') and os.path.basename(item).endswith('.jsonl')
+    }
+
+    # Determine remaining work
+    remaining_work_hashes = set(work_queue) - done_work_hashes
+    remaining_work_queue = {
+        hash_: work_queue[hash_]
+        for hash_ in remaining_work_hashes
+    }
+
+    # Populate the asyncio.Queue with remaining work
     queue = asyncio.Queue()
-
-    for work in remaining_work_queue:
-        await queue.put((work, remaining_work_queue[work]))
+    for work, pdfs in remaining_work_queue.items():
+        await queue.put((work, pdfs))
 
     return queue
 
+async def process_pdf(args, pdf_s3_path):
+    await asyncio.sleep(1)
+    return f"pdf: {pdf_s3_path}"
+
 async def worker(args, queue):
     while True:
-        work = await queue.get()
+        [work_hash, pdfs] = await queue.get()
+
+        completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
+        logger.info(f"Completed {completed_pdfs}")
 
-        logger.info(f"Got work to do for {work}")
         queue.task_done()
 
 
+async def sglang_server_task(args):
+    model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
+    #download_directory(args.model, model_cache_dir)
+
+    proc = await asyncio.create_subprocess_exec(
+        "python3",
+
+        "-m", "sglang.launch_server",
+        "--model-path", model_cache_dir,
+        "--chat-template", args.model_chat_template,
+        "--context-length", str(args.model_max_context),
+    )
+
+    await proc.wait()
+
+
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
@@ -222,25 +240,31 @@ async def main():
 
     if args.pdfs:
         await populate_pdf_work_queue(args)
 
+    sglang_server = asyncio.create_task(sglang_server_task(args))
+
     work_queue = await load_pdf_work_queue(args)
     logger.info(f"Work queue prepared with {work_queue.qsize()} items")
 
     # Create worker tasks to process the queue concurrently.
-    tasks = []
+    worker_tasks = []
     for i in range(args.workers):
         task = asyncio.create_task(worker(args, work_queue))
-        tasks.append(task)
+        worker_tasks.append(task)
 
     # Wait for the queue to be fully processed
     await work_queue.join()
 
     # Cancel our worker tasks.
-    for task in tasks:
+    for task in worker_tasks:
         task.cancel()
 
     # Wait until all worker tasks are cancelled.
-    await asyncio.gather(*tasks, return_exceptions=True)
+    await asyncio.gather(*worker_tasks, return_exceptions=True)
+
+    # Wait for server to stop
+    sglang_server.cancel()
+    await sglang_server
 
 
 if __name__ == "__main__":
diff --git a/pdelfin/s3_utils.py b/pdelfin/s3_utils.py
index ad00793..be25774 100644
--- a/pdelfin/s3_utils.py
+++ b/pdelfin/s3_utils.py
@@ -24,8 +24,8 @@ logging.basicConfig(level=logging.INFO)
 
 
 def parse_s3_path(s3_path: str) -> tuple[str, str]:
-    if not s3_path.startswith('s3://'):
-        raise ValueError('s3_path must start with s3://')
+    if not (s3_path.startswith('s3://') or s3_path.startswith('gs://') or s3_path.startswith('weka://')):
+        raise ValueError('s3_path must start with s3://, gs://, or weka://')
     parsed = urlparse(s3_path)
     bucket = parsed.netloc
     key = parsed.path.lstrip('/')
@@ -137,10 +137,10 @@ def download_directory(model_choices: list[str], local_dir: str):
     """
     Download the model to a specified local directory.
     The function will attempt to download from the first available source in the provided list.
-    Supports Google Cloud Storage (gs://) and Amazon S3 (s3://) links.
+    Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
 
     Args:
-        model_choices (list[str]): List of model paths (gs:// or s3://).
+        model_choices (list[str]): List of model paths (weka://, gs://, or s3://).
         local_dir (str): Local directory path where the model will be downloaded.
 
     Raises:
@@ -151,11 +151,20 @@ def download_directory(model_choices: list[str], local_dir: str):
     local_path.mkdir(parents=True, exist_ok=True)
     logger.info(f"Local directory set to: {local_path}")
 
+    # Reorder model_choices to prioritize weka:// links
+    weka_choices = [path for path in model_choices if path.startswith("weka://")]
+    other_choices = [path for path in model_choices if not path.startswith("weka://")]
+    prioritized_choices = weka_choices + other_choices
+
     # Iterate through the provided choices and attempt to download from the first available source
-    for model_path in model_choices:
+    for model_path in prioritized_choices:
         logger.info(f"Attempting to download from: {model_path}")
         try:
-            if model_path.startswith("gs://"):
+            if model_path.startswith("weka://"):
+                download_dir_from_weka(model_path, str(local_path))
+                logger.info(f"Successfully downloaded model from Weka: {model_path}")
+                return
+            elif model_path.startswith("gs://"):
                 download_dir_from_gcs(model_path, str(local_path))
                 logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
                 return
@@ -229,3 +238,57 @@ def download_dir_from_s3(s3_path: str, local_dir: str):
             pass
 
     logger.info(f"Downloaded model from S3 to {local_dir}")
+
+
+def download_dir_from_weka(weka_path: str, local_dir: str):
+    """Download model files from Weka to a local directory."""
+    # Retrieve Weka credentials from environment variables
+    weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
+    weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
+    if not weka_access_key or not weka_secret_key:
+        raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY environment variables must be set for Weka access.")
+
+    # Configure the boto3 client for Weka
+    weka_endpoint = "https://weka-aus.beaker.org:9000"
+    boto3_config = Config(
+        max_pool_connections=50,  # Adjust this number based on your requirements
+        signature_version='s3v4',
+        retries={'max_attempts': 10, 'mode': 'standard'}
+    )
+    s3_client = boto3.client(
+        's3',
+        endpoint_url=weka_endpoint,
+        aws_access_key_id=weka_access_key,
+        aws_secret_access_key=weka_secret_key,
+        config=boto3_config
+    )
+
+    bucket, prefix = parse_s3_path(weka_path)
+    paginator = s3_client.get_paginator("list_objects_v2")
+    try:
+        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    except s3_client.exceptions.NoSuchBucket:
+        raise ValueError(f"The bucket '{bucket}' does not exist in Weka.")
+
+    objects = []
+    for page in pages:
+        if 'Contents' in page:
+            objects.extend(page['Contents'])
+
+    total_files = len(objects)
+    logger.info(f"Found {total_files} files in Weka bucket '{bucket}' with prefix '{prefix}'.")
+
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = []
+        for obj in objects:
+            key = obj["Key"]
+            relative_path = os.path.relpath(key, prefix)
+            local_file_path = os.path.join(local_dir, relative_path)
+            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
+            futures.append(executor.submit(s3_client.download_file, bucket, key, local_file_path))
+
+        # Use tqdm to display progress
+        for _ in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc="Downloading from Weka"):
+            pass
+
+    logger.info(f"Downloaded model from Weka to {local_dir}")
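
Note on the approach in load_pdf_work_queue: the two blocking S3 helpers (download_zstd_csv and expand_s3_glob) are pushed onto worker threads with asyncio.to_thread and awaited together with asyncio.gather, so neither call stalls the event loop while the other runs. A minimal, self-contained sketch of that pattern follows; fetch_index and list_outputs are hypothetical stand-ins for the real helpers, not names from this patch.

import asyncio
import time


def fetch_index() -> list[str]:
    # Hypothetical stand-in for the blocking download_zstd_csv call
    time.sleep(1)
    return ["hash1,s3://bucket/a.pdf", "hash2,s3://bucket/b.pdf"]


def list_outputs() -> list[str]:
    # Hypothetical stand-in for the blocking expand_s3_glob call
    time.sleep(1)
    return ["s3://bucket/dolma_documents/output_hash1.jsonl"]


async def main() -> None:
    # Run both blocking calls concurrently on threads; total wall time is
    # roughly one second instead of two, and the event loop stays responsive.
    index_lines, done_items = await asyncio.gather(
        asyncio.to_thread(fetch_index),
        asyncio.to_thread(list_outputs),
    )
    print(len(index_lines), len(done_items))


if __name__ == "__main__":
    asyncio.run(main())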
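
Likewise, download_dir_from_weka treats Weka as an S3-compatible store: a standard boto3 client pointed at a custom endpoint_url, with credentials read from WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY. A rough sketch of that client setup is below; the endpoint URL is the one hard-coded in the patch, while the bucket and prefix in the usage example are hypothetical.

import os

import boto3
from botocore.config import Config


def make_weka_client():
    # S3-compatible client for the Weka endpoint used by download_dir_from_weka;
    # credentials come from the same environment variables the patch reads.
    return boto3.client(
        "s3",
        endpoint_url="https://weka-aus.beaker.org:9000",
        aws_access_key_id=os.environ["WEKA_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["WEKA_SECRET_ACCESS_KEY"],
        config=Config(signature_version="s3v4",
                      retries={"max_attempts": 10, "mode": "standard"}),
    )


if __name__ == "__main__":
    client = make_weka_client()
    # List a few objects under a prefix (bucket and prefix here are hypothetical)
    page = client.list_objects_v2(Bucket="example-bucket", Prefix="models/", MaxKeys=5)
    for obj in page.get("Contents", []):
        print(obj["Key"], obj["Size"])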