mirror of https://github.com/allenai/olmocr.git (synced 2025-11-11 07:58:10 +00:00)
Claude recommends httpx instead of aiohttp, seeing if that will help with straggler timeouts
parent 4469f4b2ce
commit 27d23525b7
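As context for the diff below, a minimal sketch of the aiohttp-to-httpx mapping this commit applies; the function name, URL, and payload here are illustrative, not taken from the repository:

import httpx

async def post_json(url: str, payload: dict) -> dict:
    # httpx.AsyncClient replaces aiohttp.ClientSession; the timeout is plain
    # seconds, and connection caps move from TCPConnector into httpx.Limits.
    async with httpx.AsyncClient(timeout=600, limits=httpx.Limits(max_connections=1000)) as client:
        # No `async with client.post(...)` context manager: for non-streaming
        # requests httpx reads the body before returning the response.
        response = await client.post(url, json=payload)
        response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
        # .status_code replaces aiohttp's .status; .text is a property and
        # .json() is synchronous, since the body is already in memory.
        return response.json()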
@@ -11,7 +11,7 @@ import json
 import base64
 import atexit
 import asyncio
-import aiohttp
+import httpx
 import datetime
 import tempfile
 import random
@@ -124,7 +124,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
     }


-async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
+async def process_page(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries

@@ -144,17 +144,15 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
         )

         try:
-            async with session.post(COMPLETION_URL, json=query) as response:
-                if response.status == 400:
-                    error_text = await response.text()
-                    raise ValueError(f"Got BadRequestError from server: {error_text}, skipping this response")
-                elif response.status == 500:
-                    error_text = await response.text()
-                    raise ValueError(f"Got InternalServerError from server: {error_text}, skipping this response")
-                else:
-                    response.raise_for_status()
+            response = await session.post(COMPLETION_URL, json=query)
+            if response.status_code == 400:
+                raise ValueError(f"Got BadRequestError from server: {response.text}, skipping this response")
+            elif response.status_code == 500:
+                raise ValueError(f"Got InternalServerError from server: {response.text}, skipping this response")
+            else:
+                response.raise_for_status()

-                base_response_data = await response.json()
+            base_response_data = response.json()

            if base_response_data["usage"]["total_tokens"] > args.model_max_context:
                local_anchor_text_len = max(1, local_anchor_text_len // 2)
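Note the body-access change in this hunk: aiohttp's response.text() and response.json() are coroutines, while an httpx response (non-streaming) is fully read by the time the await returns, so .text is a plain property and .json() a synchronous method, which is why the separate `error_text = await response.text()` lines disappear. A minimal sketch of the resulting pattern, with illustrative names:

import httpx

async def post_and_parse(session: httpx.AsyncClient, url: str, query: dict) -> dict:
    # The body is available synchronously once the await returns.
    response = await session.post(url, json=query)
    if response.status_code in (400, 500):
        # .text is a property here; aiohttp required `await response.text()`
        raise ValueError(f"Server returned {response.status_code}: {response.text}")
    response.raise_for_status()
    return response.json()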
@@ -180,7 +178,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
                input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
                output_tokens=base_response_data["usage"].get("completion_tokens", 0)
            )
-        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
+        except (httpx.TimeoutException, asyncio.TimeoutError) as e:
            logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")

            # Now we want to do exponential backoff, and not count this as an actual page retry
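One behavioral difference worth flagging in this hunk: aiohttp.ClientError is the base class for aiohttp's connection errors as well as its other client failures, whereas httpx.TimeoutException covers only the timeout family (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout), so connection failures the old handler retried will now propagate. A hedged sketch of the two exception surfaces; the helper is hypothetical:

import asyncio
import httpx

async def post_classified(session: httpx.AsyncClient, url: str, query: dict) -> httpx.Response:
    try:
        return await session.post(url, json=query)
    except (httpx.TimeoutException, asyncio.TimeoutError):
        # What the commit catches: the timeout subclasses only.
        raise
    except httpx.RequestError:
        # Closest analogue of aiohttp.ClientError: also covers ConnectError,
        # ReadError, and other transport failures. The commit does not catch these.
        raise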
@@ -209,7 +207,7 @@ async def process_page(args, session: aiohttp.ClientSession, worker_id: int, pdf
    raise ValueError(f"Could not process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts")


-async def process_pdf(args, session: aiohttp.ClientSession, worker_id: int, pdf_s3_path: str):
+async def process_pdf(args, session: httpx.AsyncClient, worker_id: int, pdf_s3_path: str):
    with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
        # TODO Switch to aioboto3 or something
        data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_s3_path))
@@ -306,8 +304,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
        await tracker.clear_work(worker_id)

        try:
-            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600),
-                                             connector=aiohttp.TCPConnector(limit=1000)) as session:
+            async with httpx.AsyncClient(timeout=600, limits=httpx.Limits(max_connections=1000)) as session:
                async with asyncio.TaskGroup() as tg:
                    dolma_tasks = [tg.create_task(process_pdf(args, session, worker_id, pdf)) for pdf in work_item.s3_work_paths]
                    logger.info(f"Created all tasks for {work_item.hash}")
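Two notes on the client construction above. httpx.Limits(max_connections=1000) is the counterpart of TCPConnector(limit=1000). The timeout semantics differ subtly, though: aiohttp's ClientTimeout(total=600) is a whole-request deadline, while httpx applies timeout=600 to each network phase (connect, read, write, pool acquisition) separately, which may matter for the straggler behavior this commit is probing. A sketch of a more explicit configuration, with illustrative values:

import httpx

# 600 s per phase by default, with a tighter connect deadline; the 30 s and
# keepalive cap below are illustrative choices, not values from the commit.
timeout = httpx.Timeout(600.0, connect=30.0)
limits = httpx.Limits(max_connections=1000, max_keepalive_connections=100)
client = httpx.AsyncClient(timeout=timeout, limits=limits)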
@@ -466,13 +463,14 @@ async def sglang_server_ready():

    for attempt in range(1, max_attempts + 1):
        try:
-            async with aiohttp.ClientSession() as session:
-                async with session.get(url) as response:
-                    if response.status == 200:
-                        logger.info("sglang server is ready.")
-                        return
-                    else:
-                        logger.info(f"Attempt {attempt}: Unexpected status code {response.status}")
+            async with httpx.AsyncClient() as session:
+                response = await session.get(url)
+
+                if response.status_code == 200:
+                    logger.info("sglang server is ready.")
+                    return
+                else:
+                    logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception as e:
            logger.warning(f"Attempt {attempt}: {e}")

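For completeness, the readiness-polling pattern this hunk converts, as a self-contained sketch; the url argument, the max_attempts default, and the 1 s retry delay are assumptions, since those lines sit outside the diff:

import asyncio
import httpx

async def wait_until_ready(url: str, max_attempts: int = 300) -> None:
    # Poll an HTTP endpoint until it answers 200 OK, as sglang_server_ready() does.
    for attempt in range(1, max_attempts + 1):
        try:
            async with httpx.AsyncClient() as session:
                response = await session.get(url)
                if response.status_code == 200:
                    return
                print(f"Attempt {attempt}: Unexpected status code {response.status_code}")
        except Exception as e:
            print(f"Attempt {attempt}: {e}")
        await asyncio.sleep(1)  # assumed delay between attempts
    raise RuntimeError(f"Server at {url} not ready after {max_attempts} attempts")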
In the version module, the patch number is bumped:

@@ -2,7 +2,7 @@ _MAJOR = "0"
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "33"
+_PATCH = "34"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""