Merge pull request #291 from haydn-jones/main

Forward unknown args to vLLM
Jake Poznanski 2025-08-04 11:04:46 -07:00 committed by GitHub
commit 5e991b67e5

@@ -568,7 +568,7 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
         semaphore.release()
 
-async def vllm_server_task(model_name_or_path, args, semaphore):
+async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
     cmd = [
         "vllm",
         "serve",
@@ -592,6 +592,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     if args.max_model_len is not None:
         cmd.extend(["--max-model-len", str(args.max_model_len)])
+    if unknown_args:
+        cmd.extend(unknown_args)
     proc = await asyncio.create_subprocess_exec(
         *cmd,
         stdout=asyncio.subprocess.PIPE,
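
The heart of the change is the two added lines above: anything argparse did not recognize is appended verbatim to the `vllm serve` command line. A minimal, self-contained sketch of that assembly (the helper name `build_vllm_cmd` and the model name are invented for illustration, not taken from the repository):

# Minimal sketch (not the repository's exact code) of how unrecognized CLI
# tokens end up on the `vllm serve` command line. build_vllm_cmd and the
# model name below are illustrative only.
def build_vllm_cmd(model, max_model_len=None, unknown_args=None):
    cmd = ["vllm", "serve", model]
    if max_model_len is not None:
        cmd.extend(["--max-model-len", str(max_model_len)])
    if unknown_args:
        # Whatever argparse could not match is forwarded verbatim to vLLM.
        cmd.extend(unknown_args)
    return cmd

print(build_vllm_cmd("some/model", 16384, ["--tensor-parallel-size", "2"]))
# ['vllm', 'serve', 'some/model', '--max-model-len', '16384',
#  '--tensor-parallel-size', '2']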
@@ -681,12 +684,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
 
-async def vllm_server_host(model_name_or_path, args, semaphore):
+async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
     MAX_RETRIES = 5
     retry = 0
     while retry < MAX_RETRIES:
-        await vllm_server_task(model_name_or_path, args, semaphore)
+        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
         logger.warning("VLLM server task ended")
         retry += 1
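
The retry wrapper above restarts the server task a bounded number of times and now threads `unknown_args` through on every restart. A simplified standalone sketch of that pattern (the names here are stand-ins, not the pipeline's actual functions):

# Simplified stand-alone sketch of the bounded-retry pattern above; server_task
# is a stub, not the real vllm_server_task.
import asyncio
import logging

logger = logging.getLogger(__name__)

async def server_task(unknown_args=None):
    await asyncio.sleep(0.1)  # stand-in for running `vllm serve` until it exits

async def server_host(unknown_args=None, max_retries=5):
    retry = 0
    while retry < max_retries:
        await server_task(unknown_args)  # extra vLLM flags ride along on each restart
        logger.warning("server task ended")
        retry += 1

asyncio.run(server_host(["--tensor-parallel-size", "2"]))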
@@ -996,7 +999,7 @@ def print_stats(args, root_work_queue):
 async def main():
-    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
+    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
     parser.add_argument(
         "workspace",
         help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
@@ -1028,7 +1031,10 @@ async def main():
     parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
     parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
-    vllm_group = parser.add_argument_group("VLLM Forwarded arguments")
+    vllm_group = parser.add_argument_group(
+        "VLLM arguments",
+        "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
+    )
     vllm_group.add_argument(
         "--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve)."
     )
@@ -1049,7 +1055,7 @@ async def main():
     beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
-    args = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
     logger.info(
         "If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
@@ -1194,7 +1200,7 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)
-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore))
+    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
     await vllm_server_ready()