Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-12 08:43:32 +00:00)
Merge pull request #291 from haydn-jones/main
Forward unknown args to vLLM
This commit is contained in: commit 5e991b67e5
@@ -568,7 +568,7 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
         semaphore.release()
 
 
-async def vllm_server_task(model_name_or_path, args, semaphore):
+async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=None):
     cmd = [
         "vllm",
         "serve",
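The hunks in this commit thread the new unknown_args parameter from argument parsing all the way to the vllm serve command line. Below is a minimal, self-contained sketch of that pattern; the flag names and model name are illustrative, not the pipeline's actual options.

import argparse

# Collect the program's own flags; anything argparse does not recognize is
# returned separately instead of raising an error.
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="some/default-model")
args, unknown_args = parser.parse_known_args(["--model", "my-model", "--tensor-parallel-size", "2"])

# Build the server command and append the unrecognized flags verbatim.
cmd = ["vllm", "serve", args.model]
if unknown_args:
    cmd.extend(unknown_args)

print(cmd)  # ['vllm', 'serve', 'my-model', '--tensor-parallel-size', '2']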
@@ -592,6 +592,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     if args.max_model_len is not None:
         cmd.extend(["--max-model-len", str(args.max_model_len)])
 
+    if unknown_args:
+        cmd.extend(unknown_args)
+
     proc = await asyncio.create_subprocess_exec(
         *cmd,
         stdout=asyncio.subprocess.PIPE,
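For context on where the forwarded flags end up, here is a hedged sketch of the launch path around asyncio.create_subprocess_exec: the finished cmd list, including any forwarded args, becomes the subprocess argv, and stdout is streamed line by line. echo stands in for the real vllm serve invocation.

import asyncio

async def launch(cmd: list[str]) -> int:
    # Spawn the command; each list element becomes one argv entry, so forwarded
    # flags need no extra quoting or shell escaping.
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    assert proc.stdout is not None
    async for line in proc.stdout:  # stream output until the process exits
        print(line.decode().rstrip())
    return await proc.wait()

print(asyncio.run(launch(["echo", "--max-model-len", "16384"])))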
@@ -681,12 +684,12 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
     await asyncio.gather(stdout_task, stderr_task, timeout_task, return_exceptions=True)
 
 
-async def vllm_server_host(model_name_or_path, args, semaphore):
+async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=None):
     MAX_RETRIES = 5
     retry = 0
 
     while retry < MAX_RETRIES:
-        await vllm_server_task(model_name_or_path, args, semaphore)
+        await vllm_server_task(model_name_or_path, args, semaphore, unknown_args)
         logger.warning("VLLM server task ended")
         retry += 1
 
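The host wrapper keeps the same restart behaviour as before; the only change is that the forwarded args are passed unchanged on every attempt. A small sketch of that loop shape, with a stand-in for the real server task:

import asyncio
import logging

logger = logging.getLogger(__name__)

async def server_task(unknown_args=None):
    # Stand-in for vllm_server_task: pretend the server ran for a while and exited.
    await asyncio.sleep(0.1)

async def server_host(unknown_args=None, max_retries: int = 5):
    retry = 0
    while retry < max_retries:
        # The forwarded args are identical on every restart attempt.
        await server_task(unknown_args)
        logger.warning("server task ended, restarting (attempt %d of %d)", retry + 1, max_retries)
        retry += 1

asyncio.run(server_host(["--tensor-parallel-size", "2"]))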
@@ -996,7 +999,7 @@ def print_stats(args, root_work_queue):
 
 
 async def main():
-    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
+    parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline.")
     parser.add_argument(
         "workspace",
         help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
@@ -1028,7 +1031,10 @@ async def main():
     parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
     parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
 
-    vllm_group = parser.add_argument_group("VLLM Forwarded arguments")
+    vllm_group = parser.add_argument_group(
+        "VLLM arguments",
+        "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
+    )
     vllm_group.add_argument(
         "--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve)."
     )
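The renamed group also gains a description, the second positional argument of add_argument_group, which argparse prints under the group title in --help output. A minimal sketch (the flag and prog name are illustrative):

import argparse

parser = argparse.ArgumentParser(prog="pipeline")
vllm_group = parser.add_argument_group(
    "VLLM arguments",
    "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM.",
)
vllm_group.add_argument("--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate")
parser.print_help()  # the group title and description appear above its options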
@@ -1049,7 +1055,7 @@ async def main():
     beaker_group.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     beaker_group.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
 
-    args = parser.parse_args()
+    args, unknown_args = parser.parse_known_args()
 
     logger.info(
         "If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
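parse_known_args() is what makes the forwarding work: flags registered on the parser are consumed into args as before, while anything unrecognized comes back as a list of raw tokens. One consequence worth noting is that a flag the pipeline already defines can never reach vLLM through unknown_args, because argparse consumes it first. A short sketch (flag names illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int, default=8)  # a pipeline-owned flag
args, unknown_args = parser.parse_known_args(["--workers", "4", "--swap-space", "8"])

print(args.workers)   # 4  -> consumed by the pipeline parser
print(unknown_args)   # ['--swap-space', '8']  -> left over, available for forwarding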
@@ -1194,7 +1200,7 @@ async def main():
     # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
     semaphore = asyncio.Semaphore(1)
 
-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore))
+    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
 
     await vllm_server_ready()
 
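Finally, the server host is started as a background task while start-up continues, and the pipeline only proceeds once the server reports ready. A hedged sketch of that orchestration, with stand-ins for vllm_server_host and vllm_server_ready:

import asyncio

async def server_host(ready: asyncio.Event, unknown_args=None):
    await asyncio.sleep(0.1)   # stand-in for server start-up
    ready.set()                # signal "server ready"
    await asyncio.sleep(3600)  # keep serving until cancelled

async def main():
    ready = asyncio.Event()
    # Run the server host in the background, forwarding the extra args.
    server = asyncio.create_task(server_host(ready, ["--swap-space", "8"]))
    await ready.wait()         # analogous to: await vllm_server_ready()
    print("server is up; workers can start sending requests")
    server.cancel()

asyncio.run(main())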
|