External vLLM

Repository: https://github.com/allenai/olmocr.git
commit b8a2b92174 (parent 4dbf951f45)
@@ -213,7 +213,10 @@ async def apost(url, json_data):


 async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
-    COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
+    if args.external_vllm_url:
+        COMPLETION_URL = f"{args.external_vllm_url.rstrip('/')}/v1/chat/completions"
+    else:
+        COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
     MODEL_MAX_CONTEXT = 16384
     TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]
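The hunk above only changes where page-level completion requests are sent: to the external vLLM server when --external-vllm-url is set, otherwise to the locally spawned one. A minimal standalone sketch of that selection, using aiohttp as an assumed HTTP client and a hard-coded BASE_SERVER_PORT (the real pipeline resolves both differently):

import aiohttp

BASE_SERVER_PORT = 30024  # illustrative default; the pipeline derives this from its own config

def completion_url(external_vllm_url: str | None) -> str:
    # Prefer the external server when configured; otherwise fall back to localhost.
    if external_vllm_url:
        return f"{external_vllm_url.rstrip('/')}/v1/chat/completions"
    return f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"

async def post_completion(payload: dict, external_vllm_url: str | None = None) -> dict:
    # Send an OpenAI-style chat completion request to whichever endpoint was selected.
    async with aiohttp.ClientSession() as session:
        async with session.post(completion_url(external_vllm_url), json=payload) as resp:
            return await resp.json()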
@@ -607,6 +610,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
     if unknown_args:
         cmd.extend(unknown_args)

+    breakpoint()
     proc = await asyncio.create_subprocess_exec(
         *cmd,
         stdout=asyncio.subprocess.PIPE,
@@ -730,10 +734,13 @@ async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
         sys.exit(1)


-async def vllm_server_ready():
+async def vllm_server_ready(args):
     max_attempts = 300
     delay_sec = 1
-    url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
+    if args.external_vllm_url:
+        url = f"{args.external_vllm_url.rstrip('/')}/v1/models"
+    else:
+        url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"

     for attempt in range(1, max_attempts + 1):
         try:
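vllm_server_ready now takes args so the readiness probe polls whichever /v1/models endpoint matches the selected server. A condensed, self-contained sketch of that loop, assuming aiohttp and treating any HTTP 200 as "ready" (the hunk is truncated before the real response handling):

import asyncio
import aiohttp

async def wait_until_ready(base_url: str, max_attempts: int = 300, delay_sec: float = 1.0) -> bool:
    # Poll the OpenAI-compatible /v1/models endpoint until the server answers.
    url = f"{base_url.rstrip('/')}/v1/models"
    async with aiohttp.ClientSession() as session:
        for attempt in range(1, max_attempts + 1):
            try:
                async with session.get(url) as resp:
                    if resp.status == 200:
                        return True
            except aiohttp.ClientError:
                pass  # server not reachable yet; retry after a short delay
            await asyncio.sleep(delay_sec)
    return False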
@@ -1069,6 +1076,9 @@ async def main():
     vllm_group.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
     vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
     vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
+    vllm_group.add_argument(
+        "--external-vllm-url", type=str, help="URL of external vLLM server (e.g., http://hostname:port). If provided, skips spawning local vLLM instance"
+    )

     # Beaker/job running stuff
     beaker_group = parser.add_argument_group("beaker/cluster execution")
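The new --external-vllm-url flag defaults to None, so existing invocations keep spawning a local server; passing a URL (with or without a trailing slash) routes every request to that server instead. A small argparse sketch of just this behavior; the group name and example URL are placeholders, not the pipeline's real setup:

import argparse

parser = argparse.ArgumentParser()
vllm_group = parser.add_argument_group("VLLM arguments")  # placeholder group name
vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
vllm_group.add_argument(
    "--external-vllm-url", type=str, help="URL of external vLLM server (e.g., http://hostname:port). If provided, skips spawning local vLLM instance"
)

args = parser.parse_args(["--external-vllm-url", "http://my-vllm-host:8000/"])
assert args.external_vllm_url.rstrip("/") == "http://my-vllm-host:8000"

args = parser.parse_args([])
assert args.external_vllm_url is None  # local-server code paths stay active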
@@ -1207,12 +1217,17 @@ async def main():

     # If you get this far, then you are doing inference and need a GPU
     # check_sglang_version()
-    check_torch_gpu_available()
+    if not args.external_vllm_url:
+        check_torch_gpu_available()

     logger.info(f"Starting pipeline with PID {os.getpid()}")

     # Download the model before you do anything else
-    model_name_or_path = await download_model(args.model)
+    if args.external_vllm_url:
+        logger.info(f"Using external vLLM server at {args.external_vllm_url}")
+        model_name_or_path = None
+    else:
+        model_name_or_path = await download_model(args.model)

     # Initialize the work queue
     qsize = await work_queue.initialize_queue()
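With an external server, the pipeline skips both the local GPU check and the model download, so model_name_or_path stays None and the weights are expected to be loaded on the remote server already. A small hedged sketch for checking what that server is actually serving, relying only on the standard OpenAI-compatible /v1/models response that vLLM exposes (aiohttp again assumed):

import aiohttp

async def served_model_ids(base_url: str) -> list[str]:
    # Ask the external vLLM server which model ids it currently serves.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url.rstrip('/')}/v1/models") as resp:
            resp.raise_for_status()
            data = await resp.json()
    return [model["id"] for model in data.get("data", [])]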
@@ -1226,9 +1241,12 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)

-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
+    # Start local vLLM instance if not using external one
+    vllm_server = None
+    if not args.external_vllm_url:
+        vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))

-    await vllm_server_ready()
+    await vllm_server_ready(args)

     metrics_task = asyncio.create_task(metrics_reporter(work_queue))

@@ -1241,11 +1259,16 @@ async def main():
     # Wait for all worker tasks to finish
     await asyncio.gather(*worker_tasks)

-    vllm_server.cancel()
+    # Cancel vLLM server if it was started
+    if vllm_server is not None:
+        vllm_server.cancel()
     metrics_task.cancel()

     # Wait for cancelled tasks to complete
-    await asyncio.gather(vllm_server, metrics_task, return_exceptions=True)
+    tasks_to_wait = [metrics_task]
+    if vllm_server is not None:
+        tasks_to_wait.append(vllm_server)
+    await asyncio.gather(*tasks_to_wait, return_exceptions=True)

     # Output final metrics summary
     metrics_summary = metrics.get_metrics_summary()
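Because vllm_server may now be None, shutdown only cancels and awaits the tasks that were actually created. A self-contained sketch of that pattern with placeholder coroutines standing in for vllm_server_host and metrics_reporter (names and durations are illustrative only):

import asyncio

async def placeholder_job(name: str) -> None:
    # Stand-in for a long-running background task; sleeps until cancelled.
    try:
        await asyncio.sleep(3600)
    except asyncio.CancelledError:
        print(f"{name} cancelled")
        raise

async def shutdown_demo(use_external_server: bool) -> None:
    vllm_server = None if use_external_server else asyncio.create_task(placeholder_job("vllm_server"))
    metrics_task = asyncio.create_task(placeholder_job("metrics"))

    await asyncio.sleep(0.1)  # pretend the workers ran here

    # Cancel only what was started.
    if vllm_server is not None:
        vllm_server.cancel()
    metrics_task.cancel()

    tasks_to_wait = [metrics_task]
    if vllm_server is not None:
        tasks_to_wait.append(vllm_server)
    # return_exceptions=True keeps the CancelledError from propagating out of gather.
    await asyncio.gather(*tasks_to_wait, return_exceptions=True)

asyncio.run(shutdown_demo(use_external_server=True))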