diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index c6f24f9..e4d5d14 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -624,12 +624,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non
 
     # Shared variables between tasks
     last_running_req, last_queue_req = 0, 0
-    prev_running_req_at_release = 0  # Track running requests at last semaphore release
+    running_reqs_decreased = False
     server_printed_ready_message = False
     last_semaphore_release = time.time()
 
     async def process_line(line):
-        nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
+        nonlocal last_running_req, last_queue_req, running_reqs_decreased, last_semaphore_release, server_printed_ready_message
         server_logger.info(line)
 
         if "Detected errors during sampling" in line:
@@ -640,12 +640,15 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non
             server_printed_ready_message = True
             last_semaphore_release = time.time()
 
-        match = re.search(r"Running: (\d+)", line)
-        if match:
-            last_running_req = int(match.group(1))
+        if match := re.search(r"Running: (\d+)", line):
+            current_running = int(match.group(1))
+            # Check for negative derivative (decrease in running requests), to not overload VLLM
+            if current_running < last_running_req:
+                running_reqs_decreased = True
+                logger.info(f"Running requests decreased: {last_running_req} -> {current_running}")
+            last_running_req = current_running
 
-        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
-        if match:
+        if match := re.search(r"(?:Waiting|Pending):\s*(\d+)", line):
             last_queue_req = int(match.group(1))
             logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
 
@@ -661,25 +664,25 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non
                 logger.warning(f"Got {ex} when reading log line from inference server, skipping")
 
     async def timeout_task():
-        nonlocal last_running_req, last_queue_req, last_semaphore_release, prev_running_req_at_release
+        nonlocal last_running_req, last_queue_req, last_semaphore_release, running_reqs_decreased
         try:
             while True:
                 await asyncio.sleep(1)
-                
+
                 # Check if we should release the semaphore
                 should_release = (
                     server_printed_ready_message
                     and last_queue_req == 0
                     and time.time() - last_semaphore_release > 30
                     and semaphore.locked()
-                    and (last_running_req == 0 or last_running_req < prev_running_req_at_release)
+                    and (last_running_req == 0 or running_reqs_decreased)
                 )
-                
+
                 if should_release:
                     semaphore.release()
-                    prev_running_req_at_release = last_running_req
+                    running_reqs_decreased = False  # Reset flag after release
                     last_semaphore_release = time.time()
-                    logger.info(f"Semaphore released, allowing a worker to proceed. Running requests: {last_running_req} (prev: {prev_running_req_at_release})")
+                    logger.info(f"Semaphore released, allowing a worker to proceed. Running requests: {last_running_req}")
         except asyncio.CancelledError:
             pass  # Clean up if the task is cancelled
 