mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-02 19:13:53 +00:00
new version of sglang, server restarts, semaphore timeouts
This commit is contained in:
parent
918e2f3542
commit
102c0e4cfc
@ -236,8 +236,10 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
|
||||
async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, pdf_local_path: str, page_num: int) -> PageResult:
|
||||
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
|
||||
MAX_RETRIES = 3
|
||||
|
||||
attempt = 0
|
||||
|
||||
for attempt in range(1, MAX_RETRIES + 1):
|
||||
while attempt < MAX_RETRIES:
|
||||
query = await build_page_query(
|
||||
pdf_local_path,
|
||||
page_num,
|
||||
@ -267,11 +269,20 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
|
||||
output_tokens=base_response_data["usage"].get("completion_tokens", 0)
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
|
||||
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
|
||||
|
||||
# Now we want to do exponential backoff, and not count this as an actual page retry
|
||||
# Page retrys are supposed to be for fixing bad results from the model, but actual requests to sglang
|
||||
# are supposed to work. Probably this means that the server is just restarting
|
||||
logger.info(f"Sleeping for 5 seconds on {pdf_s3_path}-{page_num} to allow server restart")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
|
||||
attempt += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}")
|
||||
logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
|
||||
attempt += 1
|
||||
|
||||
if attempt >= MAX_RETRIES:
|
||||
logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
|
||||
@ -429,25 +440,41 @@ async def sglang_server_task(args, semaphore):
|
||||
|
||||
atexit.register(_kill_proc)
|
||||
|
||||
last_queue_req = None # To track transitions
|
||||
last_running_req, last_queue_req = 0, 0 # To track transitions
|
||||
can_release_automatically = False
|
||||
last_semaphore_release = time.time()
|
||||
async def process_line(line):
|
||||
sglang_logger.info(line)
|
||||
|
||||
match = re.search(r'#running-req: (\d+)', line)
|
||||
if match:
|
||||
last_running_req = int(match.group(1))
|
||||
|
||||
if last_running_req > 0:
|
||||
can_release_automatically = True
|
||||
|
||||
# Parse the line and update semaphore if necessary
|
||||
match = re.search(r'#queue-req: (\d+)', line)
|
||||
if match:
|
||||
queue_req = int(match.group(1))
|
||||
logger.info(f"sglang queue req: {queue_req}")
|
||||
logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")
|
||||
|
||||
nonlocal last_queue_req
|
||||
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
|
||||
if last_queue_req != 0 and queue_req == 0:
|
||||
# Release the semaphore when queue_req transitions from non-zero to zero
|
||||
if semaphore.locked():
|
||||
semaphore.release()
|
||||
last_semaphore_release = time.time()
|
||||
logger.info("Semaphore released, allowing a worker to proceed.")
|
||||
|
||||
last_queue_req = queue_req
|
||||
|
||||
# And have a semaphore release automatically if there are no running requests for > 30 seconds
|
||||
if last_running_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
|
||||
semaphore.release()
|
||||
last_semaphore_release = time.time()
|
||||
logger.info("Semaphore released due to timeout, allowing a worker to proceed.")
|
||||
|
||||
async def read_stream(stream):
|
||||
while True:
|
||||
line = await stream.readline()
|
||||
@ -465,6 +492,11 @@ async def sglang_server_task(args, semaphore):
|
||||
await stderr_task
|
||||
|
||||
|
||||
async def sglang_server_host(args, semaphore):
|
||||
while True:
|
||||
await sglang_server_task(args, semaphore)
|
||||
|
||||
|
||||
async def sglang_server_ready():
|
||||
max_attempts = 300
|
||||
delay_sec = 1
|
||||
@ -528,7 +560,7 @@ async def main():
|
||||
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
|
||||
semaphore = asyncio.Semaphore(1)
|
||||
|
||||
sglang_server = asyncio.create_task(sglang_server_task(args, semaphore))
|
||||
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
|
||||
|
||||
work_queue = await load_pdf_work_queue(args)
|
||||
logger.info(f"Work queue prepared with {work_queue.qsize()} items")
|
||||
|
||||
@ -68,7 +68,7 @@ dev = [
|
||||
]
|
||||
|
||||
inference = [
|
||||
"sglang[all]"
|
||||
{ git = "https://github.com/sgl-project/sglang.git", rev = "eff468dd5a3d24646560eb044276585f7a11ac3c", subdirectory = "python", extras = ["all"] }
|
||||
]
|
||||
|
||||
train = [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user