Mirror of https://github.com/allenai/olmocr.git
Merge pull request #291 from haydn-jones/main
Forward unknown args to vLLM
Commit 5e991b67e5
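The core of the change, visible in the hunks below, is splitting the command line into arguments the pipeline itself declares and leftover tokens that are appended to the vllm serve command verbatim. A minimal standalone sketch of that pattern (the flag names and the model placeholder are illustrative, not the pipeline's actual defaults):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="<model>")
parser.add_argument("--gpu-memory-utilization", type=float)

# parse_known_args() never errors on unrecognized flags; it returns them as a list.
args, unknown_args = parser.parse_known_args(
    ["--gpu-memory-utilization", "0.8", "--swap-space", "8"]
)

cmd = ["vllm", "serve", args.model]
if args.gpu_memory_utilization is not None:
    cmd.extend(["--gpu-memory-utilization", str(args.gpu_memory_utilization)])
if unknown_args:                  # e.g. ["--swap-space", "8"]
    cmd.extend(unknown_args)      # forwarded to vLLM untouched

print(cmd)
# ['vllm', 'serve', '<model>', '--gpu-memory-utilization', '0.8', '--swap-space', '8']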
@@ -568,7 +568,7 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
         semaphore.release()
 
 
-async def vllm_server_task(model_name_or_path, args, semaphore):
+async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
     cmd = [
         "vllm",
         "serve",
@@ -592,6 +592,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     if args.max_model_len is not None:
         cmd.extend(["--max-model-len", str(args.max_model_len)])
 
+    if unknown_args:
+        cmd.extend(unknown_args)
+
     proc = await asyncio.create_subprocess_exec(
         *cmd,
         stdout=asyncio.subprocess.PIPE,
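The assembled cmd list is launched with asyncio.create_subprocess_exec, as the context lines above show. A self-contained sketch of that launch pattern, using a trivial command instead of vllm serve (the real task spawns separate stdout/stderr reader tasks, as the gather call in the next hunk suggests; communicate() here is just the simplest stand-in):

import asyncio

async def run(cmd):
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate()   # wait for exit and collect output
    return proc.returncode, stdout.decode()

print(asyncio.run(run(["python", "--version"])))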
@@ -681,12 +684,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
 
 
-async def vllm_server_host(model_name_or_path, args, semaphore):
+async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
     MAX_RETRIES = 5
     retry = 0
 
     while retry < MAX_RETRIES:
-        await vllm_server_task(model_name_or_path, args, semaphore)
+        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
         logger.warning("VLLM server task ended")
         retry += 1
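vllm_server_host is a bounded retry loop around vllm_server_task: restart the server whenever it exits, stop after MAX_RETRIES attempts. The same supervision pattern in isolation (serve_once is a hypothetical stand-in for vllm_server_task, and the give-up behaviour shown is illustrative; the real host decides separately what happens once retries are exhausted):

import asyncio
import logging

logger = logging.getLogger(__name__)

async def supervise(serve_once, max_retries=5):
    retry = 0
    while retry < max_retries:
        await serve_once()                        # returns whenever the server exits
        logger.warning("VLLM server task ended")
        retry += 1
    raise RuntimeError("vLLM server kept exiting; giving up")

# Example: a server stand-in that always exits immediately.
async def flaky_server():
    await asyncio.sleep(0)

# asyncio.run(supervise(flaky_server))  # logs five warnings, then raises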
@@ -996,7 +999,7 @@ def print_stats(args, root_work_queue):
 
 
 async def main():
-    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
+    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
     parser.add_argument(
         "workspace",
         help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
@@ -1028,7 +1031,10 @@ async def main():
     parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
     parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
 
-    vllm_group = parser.add_argument_group("VLLM Forwarded arguments")
+    vllm_group = parser.add_argument_group(
+        "VLLM arguments",
+        "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
+    )
     vllm_group.add_argument(
         "--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve)."
     )
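The argument-group rename is cosmetic as far as parsing goes: add_argument_group only changes how --help is laid out, while the forwarding itself comes from parse_known_args() in the later hunk. A quick check of that distinction (the flag values are arbitrary examples):

import argparse

parser = argparse.ArgumentParser()
vllm_group = parser.add_argument_group(
    "VLLM arguments",
    "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
)
vllm_group.add_argument("--gpu-memory-utilization", type=float)

args, unknown = parser.parse_known_args(["--gpu-memory-utilization", "0.9", "--dtype", "bfloat16"])
print(args.gpu_memory_utilization)  # 0.9
print(unknown)                      # ['--dtype', 'bfloat16']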
@@ -1049,7 +1055,7 @@ async def main():
     beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
 
-    args = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
 
     logger.info(
         "If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
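One consequence of switching to parse_known_args() worth noting: the parser no longer rejects unrecognized options, so a mistyped pipeline flag is forwarded to vLLM instead of failing fast at argparse time (the flag below is generic, just for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, default=8)

# A typo ("--wokers") is no longer an argparse error; it lands in unknown_args
# and is left for vllm serve to deal with.
args, unknown_args = parser.parse_known_args(["--wokers", "4"])
print(unknown_args)  # ['--wokers', '4']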
@@ -1194,7 +1200,7 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)
 
-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore))
+    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
 
     await vllm_server_ready()