new version of sglang, server restarts, semaphore timeouts

This commit is contained in:
Jake Poznanski 2024-11-12 10:49:13 -08:00
parent 918e2f3542
commit 102c0e4cfc
2 changed files with 40 additions and 8 deletions

View File

@ -237,7 +237,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
COMPLETION_URL = "http://localhost:30000/v1/chat/completions" COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
MAX_RETRIES = 3 MAX_RETRIES = 3
for attempt in range(1, MAX_RETRIES + 1): attempt = 0
while attempt < MAX_RETRIES:
query = await build_page_query( query = await build_page_query(
pdf_local_path, pdf_local_path,
page_num, page_num,
@ -267,11 +269,20 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
output_tokens=base_response_data["usage"].get("completion_tokens", 0) output_tokens=base_response_data["usage"].get("completion_tokens", 0)
) )
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}") logger.warning(f"Client error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
# Now we want to do exponential backoff, and not count this as an actual page retry
# Page retrys are supposed to be for fixing bad results from the model, but actual requests to sglang
# are supposed to work. Probably this means that the server is just restarting
logger.info(f"Sleeping for 5 seconds on {pdf_s3_path}-{page_num} to allow server restart")
await asyncio.sleep(5)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}") logger.warning(f"JSON decode error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
attempt += 1
except Exception as e: except Exception as e:
logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}:: {e}") logger.warning(f"Unexpected error on attempt {attempt} for {pdf_s3_path}-{page_num}: {e}")
attempt += 1
if attempt >= MAX_RETRIES: if attempt >= MAX_RETRIES:
logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.") logger.error(f"Failed to process {pdf_s3_path}-{page_num} after {MAX_RETRIES} attempts.")
@ -429,25 +440,41 @@ async def sglang_server_task(args, semaphore):
atexit.register(_kill_proc) atexit.register(_kill_proc)
last_queue_req = None # To track transitions last_running_req, last_queue_req = 0, 0 # To track transitions
can_release_automatically = False
last_semaphore_release = time.time()
async def process_line(line): async def process_line(line):
sglang_logger.info(line) sglang_logger.info(line)
match = re.search(r'#running-req: (\d+)', line)
if match:
last_running_req = int(match.group(1))
if last_running_req > 0:
can_release_automatically = True
# Parse the line and update semaphore if necessary # Parse the line and update semaphore if necessary
match = re.search(r'#queue-req: (\d+)', line) match = re.search(r'#queue-req: (\d+)', line)
if match: if match:
queue_req = int(match.group(1)) queue_req = int(match.group(1))
logger.info(f"sglang queue req: {queue_req}") logger.info(f"sglang running req: {last_running_req} queue req: {queue_req}")
nonlocal last_queue_req nonlocal last_queue_req
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0: if last_queue_req != 0 and queue_req == 0:
# Release the semaphore when queue_req transitions from non-zero to zero # Release the semaphore when queue_req transitions from non-zero to zero
if semaphore.locked(): if semaphore.locked():
semaphore.release() semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released, allowing a worker to proceed.") logger.info("Semaphore released, allowing a worker to proceed.")
last_queue_req = queue_req last_queue_req = queue_req
# And have a semaphore release automatically if there are no running requests for > 30 seconds
if last_running_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released due to timeout, allowing a worker to proceed.")
async def read_stream(stream): async def read_stream(stream):
while True: while True:
line = await stream.readline() line = await stream.readline()
@ -465,6 +492,11 @@ async def sglang_server_task(args, semaphore):
await stderr_task await stderr_task
async def sglang_server_host(args, semaphore):
while True:
await sglang_server_task(args, semaphore)
async def sglang_server_ready(): async def sglang_server_ready():
max_attempts = 300 max_attempts = 300
delay_sec = 1 delay_sec = 1
@ -528,7 +560,7 @@ async def main():
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1) semaphore = asyncio.Semaphore(1)
sglang_server = asyncio.create_task(sglang_server_task(args, semaphore)) sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
work_queue = await load_pdf_work_queue(args) work_queue = await load_pdf_work_queue(args)
logger.info(f"Work queue prepared with {work_queue.qsize()} items") logger.info(f"Work queue prepared with {work_queue.qsize()} items")

View File

@ -68,7 +68,7 @@ dev = [
] ]
inference = [ inference = [
"sglang[all]" { git = "https://github.com/sgl-project/sglang.git", rev = "eff468dd5a3d24646560eb044276585f7a11ac3c", subdirectory = "python", extras = ["all"] }
] ]
train = [ train = [