External vLLM

This commit is contained in:
Haydn Jones 2025-08-20 19:21:38 -04:00
parent 4dbf951f45
commit b8a2b92174

View File

@ -213,7 +213,10 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
if args.external_vllm_url:
COMPLETION_URL = f"{args.external_vllm_url.rstrip('/')}/v1/chat/completions"
else:
COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries
MODEL_MAX_CONTEXT = 16384
TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]
@ -607,6 +610,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non
if unknown_args:
cmd.extend(unknown_args)
breakpoint()
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
@ -730,10 +734,13 @@ async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=Non
sys.exit(1)
async def vllm_server_ready():
async def vllm_server_ready(args):
max_attempts = 300
delay_sec = 1
url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
if args.external_vllm_url:
url = f"{args.external_vllm_url.rstrip('/')}/v1/models"
else:
url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
for attempt in range(1, max_attempts + 1):
try:
@ -1069,6 +1076,9 @@ async def main():
vllm_group.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
vllm_group.add_argument(
"--external-vllm-url", type=str, help="URL of external vLLM server (e.g., http://hostname:port). If provided, skips spawning local vLLM instance"
)
# Beaker/job running stuff
beaker_group = parser.add_argument_group("beaker/cluster execution")
@ -1207,12 +1217,17 @@ async def main():
# If you get this far, then you are doing inference and need a GPU
# check_sglang_version()
check_torch_gpu_available()
if not args.external_vllm_url:
check_torch_gpu_available()
logger.info(f"Starting pipeline with PID {os.getpid()}")
# Download the model before you do anything else
model_name_or_path = await download_model(args.model)
if args.external_vllm_url:
logger.info(f"Using external vLLM server at {args.external_vllm_url}")
model_name_or_path = None
else:
model_name_or_path = await download_model(args.model)
# Initialize the work queue
qsize = await work_queue.initialize_queue()
@ -1226,9 +1241,12 @@ async def main():
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)
vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
# Start local vLLM instance if not using external one
vllm_server = None
if not args.external_vllm_url:
vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
await vllm_server_ready()
await vllm_server_ready(args)
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@ -1241,11 +1259,16 @@ async def main():
# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)
vllm_server.cancel()
# Cancel vLLM server if it was started
if vllm_server is not None:
vllm_server.cancel()
metrics_task.cancel()
# Wait for cancelled tasks to complete
await asyncio.gather(vllm_server, metrics_task, return_exceptions=True)
tasks_to_wait = [metrics_task]
if vllm_server is not None:
tasks_to_wait.append(vllm_server)
await asyncio.gather(*tasks_to_wait, return_exceptions=True)
# Output final metrics summary
metrics_summary = metrics.get_metrics_summary()