diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index a2a7466..849e0bc 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -236,8 +236,10 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
 async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
     MAX_RETRIES = 3
+
+    attempt = 0
 
-    for attempt in range(1, MAX_RETRIES + 1):
+    while attempt < MAX_RETRIES:
         query = await build_page_query(
             pdf_local_path,
             page_num,
@@ -267,11 +269,20 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
                 output_tokens=base_response_data["usage"].get("completion_tokens", 0)
             )
         except aiohttp.ClientError as e:
-            logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
+            logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
+
+            # Back off, and do not count this as an actual page retry. Page retries are meant
+            # for fixing bad results from the model, whereas requests to sglang are supposed to
+            # succeed; a client error most likely means the server is just restarting.
+            logger.info(f"Sleeping for 5 seconds on {pdf_s3_path}-{page_num} to allow server restart")
+            await asyncio.sleep(5)
+
         except json.JSONDecodeError as e:
             logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
+            attempt += 1
         except Exception as e:
-            logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
+            logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
+            attempt += 1
 
     if attempt >= MAX_RETRIES:
         logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
@@ -429,25 +440,41 @@ async def sglang_server_task(args, semaphore):
 
     atexit.register(_kill_proc)
 
-    last_queue_req = None  # To track transitions
+    last_running_req, last_queue_req = 0, 0  # To track transitions
+    can_release_automatically = False
+    last_semaphore_release = time.time()
 
     async def process_line(line):
+        nonlocal last_running_req, last_queue_req, can_release_automatically, last_semaphore_release
         sglang_logger.info(line)
+
+        match = re.search(r'#running-req: (\d+)', line)
+        if match:
+            last_running_req = int(match.group(1))
+
+            if last_running_req > 0:
+                can_release_automatically = True
 
         # Parse the line and update semaphore if necessary
         match = re.search(r'#queue-req: (\d+)', line)
         if match:
             queue_req = int(match.group(1))
-            logger.info(f"sglang queue req: {queue_req}")
+            logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")
 
-            nonlocal last_queue_req
-            if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
+            if last_queue_req != 0 and queue_req == 0:
                 # Release the semaphore when queue_req transitions from non-zero to zero
                 if semaphore.locked():
                     semaphore.release()
+                    last_semaphore_release = time.time()
                     logger.info("Semaphore released, allowing a worker to proceed.")
 
             last_queue_req = queue_req
 
+        # Release the semaphore automatically if requests have run before but none have been running for > 30 seconds
+        if can_release_automatically and last_running_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
+            semaphore.release()
+            last_semaphore_release = time.time()
+            logger.info("Semaphore released due to timeout, allowing a worker to proceed.")
+
    async def read_stream(stream):
         while True:
             line = await stream.readline()
@@ -465,6 +492,11 @@ async def sglang_server_task(args, semaphore):
         await stderr_task
 
 
+async def sglang_server_host(args, semaphore):
+    while True:
+        await sglang_server_task(args, semaphore)
+
+
 async def sglang_server_ready():
     max_attempts = 300
     delay_sec = 1
@@ -528,7 +560,7 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)
 
-    sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
+    sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
 
     work_queue = await load_pdf_work_queue(args)
     logger.info(f"Work queue prepared with {work_queue.qsize()} items")
diff --git a/pyproject.toml b/pyproject.toml
index 4d8bbf6..9d9269c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,7 @@ dev = [
 ]
 
 inference = [
-    "sglang[all]"
+    "sglang[all] @ git+https://github.com/sgl-project/sglang.git@eff468dd5a3d24646560eb044276585f7a11ac3c#subdirectory=python"
 ]
 
 train = [
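
For context on the first hunk: transport-level failures no longer consume one of the page's MAX_RETRIES attempts; only bad model output does. Below is a minimal standalone sketch of that retry pattern, assuming a hypothetical `fetch_page` coroutine in place of the real `build_page_query` + sglang request, and the builtin `ConnectionError` standing in for `aiohttp.ClientError`:

```python
import asyncio
import json
import logging

logger = logging.getLogger(__name__)

MAX_RETRIES = 3

async def process_with_retries(fetch_page):
    # fetch_page is a hypothetical coroutine standing in for the real
    # sglang completion request.
    attempt = 0
    while attempt < MAX_RETRIES:
        try:
            return await fetch_page()
        except ConnectionError as e:
            # Transport failure: back off but do not consume a retry;
            # the server is most likely just restarting.
            logger.warning(f"Client error on attempt {attempt}: {e}")
            await asyncio.sleep(5)
        except json.JSONDecodeError as e:
            # Bad output from the model: this does consume a retry.
            logger.warning(f"JSON decode error on attempt {attempt}: {e}")
            attempt += 1
    raise ValueError(f"Failed after {MAX_RETRIES} attempts")
```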