From 995b1d15fccfdfaa83623fb8cf28f834fc8c0631 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 18 Nov 2024 09:55:45 -0800
Subject: [PATCH] Fixes, mocking out queue into separate file

---
 pdelfin/beakerpipeline.py |   8 +-
 pdelfin/s3_queue.py       | 164 ++++++++++++++++++++++++++++++++++++++
 pdelfin/version.py        |   2 +-
 3 files changed, 171 insertions(+), 3 deletions(-)
 create mode 100644 pdelfin/s3_queue.py

diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index be8122f..bd1fb70 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -550,16 +550,20 @@ async def sglang_server_task(args, semaphore):
 
     # Shared variables between tasks
     last_running_req, last_queue_req = 0, 0
+    server_printed_ready_message = False
     last_semaphore_release = time.time()
 
     async def process_line(line):
-        nonlocal last_running_req, last_queue_req, last_semaphore_release
+        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
         sglang_logger.info(line)
 
         if "Detected errors during sampling" in line:
             logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
             sys.exit(1)
 
+        if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
+            server_printed_ready_message = True
+
         match = re.search(r'#running-req: (\d+)', line)
         if match:
             last_running_req = int(match.group(1))
@@ -582,7 +586,7 @@ async def sglang_server_task(args, semaphore):
     try:
         while True:
             await asyncio.sleep(1)
-            if last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
+            if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
                 semaphore.release()
                 last_semaphore_release = time.time()
                 logger.info("Semaphore released, allowing a worker to proceed.")
diff --git a/pdelfin/s3_queue.py b/pdelfin/s3_queue.py
new file mode 100644
index 0000000..43a2fe9
--- /dev/null
+++ b/pdelfin/s3_queue.py
@@ -0,0 +1,164 @@
+import os
+import random
+import logging
+import hashlib
+import tempfile
+from typing import Optional, Tuple, List, Dict, Set
+from dataclasses import dataclass
+import asyncio
+from functools import partial
+
+from pdelfin.s3_utils import (
+    expand_s3_glob,
+    download_zstd_csv,
+    upload_zstd_csv,
+    parse_s3_path
+)
+from pypdf import PdfReader
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class WorkItem:
+    """Represents a single work item in the queue"""
+    hash: str
+    s3_work_paths: List[str]
+
+class S3WorkQueue:
+    """
+    Manages a work queue stored in S3 that coordinates work across multiple workers.
+    The queue maintains a list of work items, where each work item is a group of s3
+    paths that should be processed together.
+
+    Each work item gets a hash; a completed work item has its results stored in
+    s3://workspace_path/results/output_[hash].jsonl
+
+    That results prefix is the source of truth about which work items are done.
+
+    When a worker takes an item off the queue, it writes an empty s3 file to
+    s3://workspace_path/worker_locks/output_[hash].jsonl
+
+    The queue is randomized on each worker, so workers pull random work items to
+    operate on. As an item is pulled, the worker checks whether it has already been
+    completed; if so, it immediately fetches the next item. If a lock file was
+    created within a configurable timeout (30 minutes by default), that work item
+    is also skipped.
+
+    The lock is deleted once the worker is done with that item.
+    """
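+
+    # A minimal usage sketch (an assumption about how callers are meant to drive
+    # the queue; get_work(), populate_queue(), and mark_done() are still stubs
+    # in this commit):
+    #
+    #     work_queue = S3WorkQueue(s3_client, "s3://bucket/workspace")
+    #     await work_queue.initialize_queue()
+    #     while (item := await work_queue.get_work()) is not None:
+    #         process(item.s3_work_paths)  # hypothetical per-group worker function
+    #         work_queue.mark_done(item)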
+ """ + def __init__(self, s3_client, workspace_path: str): + """ + Initialize the work queue. + + Args: + s3_client: Boto3 S3 client to use for operations + workspace_path: S3 path where work queue and results are stored + """ + self.s3_client = s3_client + self.workspace_path = workspace_path.rstrip('/') + + self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd") + self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl") + + @staticmethod + def _compute_workgroup_hash(s3_work_paths: List[str]) -> str: + """ + Compute a deterministic hash for a group of PDFs. + + Args: + pdfs: List of PDF S3 paths + + Returns: + SHA1 hash of the sorted PDF paths + """ + sha1 = hashlib.sha1() + for pdf in sorted(s3_work_paths): + sha1.update(pdf.encode('utf-8')) + return sha1.hexdigest() + + + async def populate_queue(self, s3_work_paths: str, items_per_group: int) -> None: + pass + + async def initialize_queue(self) -> None: + """ + Load the work queue from S3 and initialize it for processing. + Removes already completed work items and randomizes the order. + """ + # Load work items and completed items in parallel + download_task = asyncio.to_thread( + download_zstd_csv, + self.s3_client, + self._index_path + ) + expand_task = asyncio.to_thread( + expand_s3_glob, + self.s3_client, + self._output_glob + ) + + work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task) + + # Process work queue lines + work_queue = { + parts[0]: parts[1:] + for line in work_queue_lines + if (parts := line.strip().split(",")) and line.strip() + } + + # Get set of completed work hashes + done_work_hashes = { + os.path.basename(item)[len('output_'):-len('.jsonl')] + for item in done_work_items + if os.path.basename(item).startswith('output_') + and os.path.basename(item).endswith('.jsonl') + } + + # Find remaining work and shuffle + remaining_work_hashes = set(work_queue) - done_work_hashes + remaining_items = [ + WorkItem(hash_=hash_, pdfs=work_queue[hash_]) + for hash_ in remaining_work_hashes + ] + random.shuffle(remaining_items) + + # Initialize queue + self._queue = asyncio.Queue() + for item in remaining_items: + await self._queue.put(item) + + logger.info(f"Initialized queue with {self._queue.qsize()} work items") + + async def is_completed(self, work_hash: str) -> bool: + """ + Check if a work item has been completed. + + Args: + work_hash: Hash of the work item to check + + Returns: + True if the work is completed, False otherwise + """ + output_s3_path = ""TODO"" + bucket, key = parse_s3_path(output_s3_path) + + try: + await asyncio.to_thread( + self.s3_client.head_object, + Bucket=bucket, + Key=key + ) + return True + except self.s3_client.exceptions.ClientError: + return False + + async def get_work(self) -> Optional[WorkItem]: + pass + + def mark_done(self, work_item: WorkItem) -> None: + """Mark the most recently gotten work item as complete""" + pass + + @property + def size(self) -> int: + """Get current size of work queue""" + return self._queue.qsize() \ No newline at end of file diff --git a/pdelfin/version.py b/pdelfin/version.py index 8059018..abbeb42 100644 --- a/pdelfin/version.py +++ b/pdelfin/version.py @@ -2,7 +2,7 @@ _MAJOR = "0" _MINOR = "1" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "24" +_PATCH = "25" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = ""