diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index a7958af..f51ae5e 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -568,7 +568,7 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
         semaphore.release()
 
 
-async def vllm_server_task(model_name_or_path, args, semaphore):
+async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
     cmd = [
         "vllm",
         "serve",
@@ -592,6 +592,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     if args.max_model_len is not None:
         cmd.extend(["--max-model-len", str(args.max_model_len)])
 
+    if unknown_args:
+        cmd.extend(unknown_args)
+
     proc = await asyncio.create_subprocess_exec(
         *cmd,
         stdout=asyncio.subprocess.PIPE,
@@ -681,12 +684,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
 
 
-async def vllm_server_host(model_name_or_path, args, semaphore):
+async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
     MAX_RETRIES = 5
     retry = 0
 
     while retry < MAX_RETRIES:
-        await vllm_server_task(model_name_or_path, args, semaphore)
+        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
         logger.warning("VLLM server task ended")
         retry += 1
 
@@ -996,7 +999,7 @@ def print_stats(args, root_work_queue):
 
 
 async def main():
-    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
+    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
     parser.add_argument(
         "workspace",
         help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
@@ -1028,7 +1031,10 @@ async def main():
     parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
     parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
 
-    vllm_group = parser.add_argument_group("VLLM Forwarded arguments")
+    vllm_group = parser.add_argument_group(
+        "VLLM arguments",
+        "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
+    )
     vllm_group.add_argument(
         "--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve)."
     )
@@ -1049,7 +1055,7 @@ async def main():
     beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
 
-    args = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
     logger.info(
         "If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
     )
@@ -1194,7 +1200,7 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)
 
-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore))
+    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
 
     await vllm_server_ready()
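
The mechanism this patch introduces hinges on `argparse.ArgumentParser.parse_known_args()`: unlike `parse_args()`, it does not error out on unrecognized flags but returns them as a list of leftover tokens, which the pipeline then appends verbatim to the `vllm serve` command line. Below is a minimal, self-contained sketch of that forwarding pattern under stated assumptions: the toy parser, the `example/model` name, and the `--tensor-parallel-size` flag are illustrative stand-ins, not olmocr's actual CLI surface.

```python
import argparse

# Toy parser standing in for olmocr's much larger CLI (illustrative only).
parser = argparse.ArgumentParser(description="Toy batch inference pipeline")
parser.add_argument("workspace")
parser.add_argument("--max_model_len", type=int, default=None)

# parse_known_args() returns (namespace, leftover_tokens) instead of
# raising an error on flags the parser does not know about.
args, unknown_args = parser.parse_known_args(
    ["./workspace", "--max_model_len", "16384", "--tensor-parallel-size", "2"]
)

# Build the vllm serve command the same way the patched vllm_server_task does:
# known flags are mapped explicitly, everything else is forwarded verbatim.
cmd = ["vllm", "serve", "example/model"]
if args.max_model_len is not None:
    cmd.extend(["--max-model-len", str(args.max_model_len)])
if unknown_args:
    cmd.extend(unknown_args)

print(cmd)
# ['vllm', 'serve', 'example/model', '--max-model-len', '16384',
#  '--tensor-parallel-size', '2']
```

In practice this means any flag that `vllm serve` itself understands can be supplied directly on the pipeline's command line without olmocr having to declare a matching argument for it.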