diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 36689c3..8c4daa3 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -567,10 +567,6 @@ async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
 
 
 async def vllm_server_task(model_name_or_path, args, semaphore):
-    # Check GPU memory, lower mem devices need a bit less KV cache space because the VLM takes additional memory
-    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)  # Convert to GB
-    mem_fraction_arg = ["--gpu-memory-utilization", "0.80"] if gpu_memory < 60 else []
-
     cmd = [
         "vllm",
         "serve",
@@ -582,8 +578,11 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
         "warning",
         "--served-model-name",
         "Qwen/Qwen2-VL-7B-Instruct",
+        "--tensor-parallel-size",
+        str(args.tensor_parallel_size),
+        "--data-parallel-size",
+        str(args.data_parallel_size),
     ]
-    cmd.extend(mem_fraction_arg)
 
     proc = await asyncio.create_subprocess_exec(
         *cmd,
@@ -623,7 +622,7 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
         if match:
             last_running_req = int(match.group(1))
 
-        match = re.search(r'(?:Waiting|Pending):\s*(\d+)', line)
+        match = re.search(r"(?:Waiting|Pending):\s*(\d+)", line)
         if match:
             last_queue_req = int(match.group(1))
             logger.info(f"vllm running req: {last_running_req} queue req: {last_queue_req}")
@@ -1025,6 +1024,8 @@ async def main():
     parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
     parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
     parser.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
+    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
+    parser.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
     args = parser.parse_args()
 
     global workspace_s3, pdf_s3