Better queue management again

This commit is contained in:
Jake Poznanski 2025-08-14 16:37:11 +00:00
parent 38679243d7
commit 0a8cd93c0a


@@ -624,12 +624,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None
     # Shared variables between tasks
     last_running_req, last_queue_req = 0, 0
-    prev_running_req_at_release = 0  # Track running requests at last semaphore release
+    running_reqs_decreased = False
     server_printed_ready_message = False
     last_semaphore_release = time.time()

     async def process_line(line):
-        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
+        nonlocal last_running_req, last_queue_req, running_reqs_decreased, last_semaphore_release, server_printed_ready_message
         server_logger.info(line)

         if "Detected errors during sampling" in line:
@@ -640,12 +640,15 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None
             server_printed_ready_message = True
             last_semaphore_release = time.time()

-        match = re.search(r"Running: (\d+)", line)
-        if match:
-            last_running_req = int(match.group(1))
+        if match := re.search(r"Running: (\d+)", line):
+            current_running = int(match.group(1))
+            # Check for negative derivative (decrease in running requests), to not overload VLLM
+            if current_running < last_running_req:
+                running_reqs_decreased = True
+                logger.info(f"Running requests decreased: {last_running_req} -> {current_running}")
+            last_running_req = current_running

-        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
-        if match:
+        if match := re.search(r"(?:Waiting|Pending):\s*(\d+)", line):
             last_queue_req = int(match.group(1))
             logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
@@ -661,7 +664,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None
             logger.warning(f"Got {ex} when reading log line from inference server, skipping")

     async def timeout_task():
-        nonlocal last_running_req, last_queue_req, last_semaphore_release, prev_running_req_at_release
+        nonlocal last_running_req, last_queue_req, last_semaphore_release, running_reqs_decreased
         try:
             while True:
                 await asyncio.sleep(1)
@@ -672,14 +675,14 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None
                     and last_queue_req == 0
                     and time.time() - last_semaphore_release > 30
                     and semaphore.locked()
-                    and (last_running_req == 0 or last_running_req < prev_running_req_at_release)
+                    and (last_running_req == 0 or running_reqs_decreased)
                 )

                 if should_release:
                     semaphore.release()
-                    prev_running_req_at_release = last_running_req
+                    running_reqs_decreased = False  # Reset flag after release
                     last_semaphore_release = time.time()
-                    logger.info(f"Semaphore released, allowing a worker to proceed. Running requests: {last_running_req} (prev: {prev_running_req_at_release})")
+                    logger.info(f"Semaphore released, allowing a worker to proceed. Running requests: {last_running_req}")
         except asyncio.CancelledError:
             pass  # Clean up if the task is cancelled
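Read together, the release gate now admits one more worker only when the queue is drained, at least 30 seconds have passed since the last release, the semaphore is held, and vLLM has demonstrated progress (the running count is zero or has dropped since the last release); the flag is then reset so the next release requires fresh evidence. A standalone restatement of the visible clauses (the hunk truncates at least one earlier clause of the condition; the function name and parameters here are illustrative):

import time

def should_release(last_running_req, last_queue_req,
                   last_semaphore_release, semaphore_locked,
                   running_reqs_decreased):
    # Mirrors the gate in timeout_task: queue drained, 30 s cooldown
    # elapsed, semaphore held, and vLLM shown to be making progress.
    return (
        last_queue_req == 0
        and time.time() - last_semaphore_release > 30
        and semaphore_locked
        and (last_running_req == 0 or running_reqs_decreased)
    )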