From 27d23525b79e959cc0e3d9e08b2a2a976a6fb697 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 19 Nov 2024 10:41:58 -0800
Subject: [PATCH] Claude recommends httpx instead of aiohttp, seeing if that
 will help with straggler timeouts

---
 pdelfin/beakerpipeline.py | 44 +++++++++++++++++++--------------------
 pdelfin/version.py        |  2 +-
 2 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index 81ce75f..78a9110 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -11,7 +11,7 @@ import json
 import base64
 import atexit
 import asyncio
-import aiohttp
+import httpx
 import datetime
 import tempfile
 import random
@@ -124,7 +124,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
     }
 
 
-async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
+async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
 
@@ -144,17 +144,15 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
         )
 
         try:
-            async with session.post(COMPLETION_URL, json=query) as response:
-                if response.status == 400:
-                    error_text = await response.text()
-                    raise ValueError(f"Got BadRequestError from server: {error_text}, skipping this response")
-                elif response.status == 500:
-                    error_text = await response.text()
-                    raise ValueError(f"Got InternalServerError from server: {error_text}, skipping this response")
-                else:
-                    response.raise_for_status()
+            response = await session.post(COMPLETION_URL, json=query)
+            if response.status_code == 400:
+                raise ValueError(f"Got BadRequestError from server: {response.text}, skipping this response")
+            elif response.status_code == 500:
+                raise ValueError(f"Got InternalServerError from server: {response.text}, skipping this response")
+            else:
+                response.raise_for_status()
 
-                base_response_data = await response.json()
+            base_response_data = response.json()
 
             if base_response_data["usage"]["total_tokens"] > args.model_max_context:
                 local_anchor_text_len = max(1, local_anchor_text_len // 2)
@@ -180,7 +178,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
                 input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                 output_tokens=base_response_data["usage"].get("completion_tokens", 0)
             )
-        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
+        except (httpx.TimeoutException, asyncio.TimeoutError) as e:
             logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
 
             # Now we want to do exponential backoff, and not count this as an actual page retry
@@ -209,7 +207,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
     raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")
 
 
-async def process_pdf(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str):
+async def process_pdf(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str):
     with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
         # TODO Switch to aioboto3 or something
         data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_s3_path))
@@ -306,8 +304,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
             await tracker.clear_work(worker_id)
 
             try:
-                async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
-                                                 connector=aiohttp.TCPConnector(limit=1000)) as session:
+                async with httpx.AsyncClient(timeout=600, limits=httpx.Limits(max_connections=1000)) as session:
                     async with asyncio.TaskGroup() as tg:
                         dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
                         logger.info(f"Created all tasks for {work_item.hash}")
@@ -466,13 +463,14 @@ async def sglang_server_ready():
 
     for attempt in range(1, max_attempts + 1):
         try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(url) as response:
-                    if response.status == 200:
-                        logger.info("sglang server is ready.")
-                        return
-                    else:
-                        logger.info(f"Attempt {attempt}: Unexpected status code {response.status}")
+            async with httpx.AsyncClient() as session:
+                response = await session.get(url)
+
+                if response.status_code == 200:
+                    logger.info("sglang server is ready.")
+                    return
+                else:
+                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
         except Exception as e:
             logger.warning(f"Attempt {attempt}: {e}")
 
diff --git a/pdelfin/version.py b/pdelfin/version.py
index 0e7137a..1a72901 100644
--- a/pdelfin/version.py
+++ b/pdelfin/version.py
@@ -2,7 +2,7 @@ _MAJOR = "0"
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "33"
+_PATCH = "34"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
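
Note on the semantics of the switch: aiohttp's ClientTimeout(total=600) caps
the entire request at 600 seconds, whereas timeout=600 on httpx.AsyncClient
applies 600 seconds to each phase (connect, read, write, pool) with no
overall cap, so a slowly streaming straggler can still exceed 600 seconds of
wall time. Also, httpx.TimeoutException is narrower than the old
aiohttp.ClientError: connection failures such as httpx.ConnectError are
sibling httpx.RequestError subclasses and will no longer hit the retry
branch. Below is a minimal standalone sketch of the calling pattern this
patch adopts, with a granular timeout; the 30-second connect value and the
port are illustrative assumptions, not taken from the patch:

    import asyncio
    import httpx

    # 600 s and 1000 connections mirror the patch; the split-out connect
    # timeout and the port below are assumptions for illustration only.
    TIMEOUT = httpx.Timeout(600.0, connect=30.0)
    LIMITS = httpx.Limits(max_connections=1000)
    COMPLETION_URL = "http://localhost:30024/v1/chat/completions"  # hypothetical port

    async def post_query(query: dict) -> dict:
        async with httpx.AsyncClient(timeout=TIMEOUT, limits=LIMITS) as session:
            response = await session.post(COMPLETION_URL, json=query)
            # httpx reads the body eagerly, so .text is a plain property and
            # .json() is synchronous, unlike aiohttp's awaitable accessors.
            if response.status_code in (400, 500):
                raise ValueError(f"Got error from server: {response.text}")
            response.raise_for_status()
            return response.json()

    if __name__ == "__main__":
        # Hypothetical usage; fails with httpx.ConnectError if no server is up.
        print(asyncio.run(post_query({"messages": []})))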