diff --git a/README.md b/README.md
index 024b641..e385e70 100644
--- a/README.md
+++ b/README.md
@@ -249,6 +249,26 @@ For example:
 ```bash
 python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4
 ```
+### Using DeepInfra
+Sign up at [DeepInfra](https://deepinfra.com/) and get your API key from the DeepInfra dashboard.
+Store the API key in an environment variable:
+```bash
+export DEEPINFRA_API_KEY="your-api-key-here"
+```
+#### Run olmOCR with the DeepInfra server endpoint
+```bash
+python -m olmocr.pipeline ./localworkspace \
+    --server https://api.deepinfra.com/v1/openai \
+    --api_key $DEEPINFRA_API_KEY \
+    --model allenai/olmOCR-7B-0725-FP8 \
+    --markdown \
+    --pdfs path/to/your/*.pdf
+```
+- `--server`: DeepInfra's OpenAI-compatible endpoint: `https://api.deepinfra.com/v1/openai`
+- `--api_key`: Your DeepInfra API key
+- `--model`: The model identifier on DeepInfra: `allenai/olmOCR-7B-0725-FP8`
+- Other arguments work the same as with local inference
+
 ### Using Docker
 
diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 04a2170..1c5febb 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -11,6 +11,7 @@ import os
 import random
 import re
 import shutil
+import ssl
 import sys
 import tempfile
 import time
@@ -104,7 +105,7 @@ class PageResult:
     is_fallback: bool
 
 
-async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
+async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0, model_name: str = "olmocr") -> dict:
     MAX_TOKENS = 4500
     assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
@@ -132,7 +133,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
     image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
     return {
-        "model": "olmocr",
+        "model": model_name,
         "messages": [
             {
                 "role": "user",
@@ -151,29 +152,44 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
 # It feels strange perhaps, but httpx and aiohttp are very complex beasts
 # Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
 # that at the scale of 100M+ requests, that they deadlock in different strange ways
-async def apost(url, json_data):
+async def apost(url, json_data, api_key=None):
     parsed_url = urlparse(url)
     host = parsed_url.hostname
-    port = parsed_url.port or 80
+    # Default to 443 for HTTPS, 80 for HTTP
+    if parsed_url.scheme == "https":
+        port = parsed_url.port or 443
+        use_ssl = True
+    else:
+        port = parsed_url.port or 80
+        use_ssl = False
     path = parsed_url.path or "/"
     writer = None
 
     try:
-        reader, writer = await asyncio.open_connection(host, port)
+        if use_ssl:
+            ssl_context = ssl.create_default_context()
+            reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
+        else:
+            reader, writer = await asyncio.open_connection(host, port)
 
         json_payload = json.dumps(json_data)
-        request = (
-            f"POST {path} HTTP/1.1\r\n"
-            f"Host: {host}\r\n"
-            f"Content-Type: application/json\r\n"
-            f"Content-Length: {len(json_payload)}\r\n"
-            f"Connection: close\r\n\r\n"
-            f"{json_payload}"
-        )
+
+        headers = [
+            f"POST {path} HTTP/1.1",
+            f"Host: {host}",
+            f"Content-Type: application/json",
+            f"Content-Length: {len(json_payload)}",
+        ]
+
+        if api_key:
+            headers.append(f"Authorization: Bearer {api_key}")
+
+        headers.append("Connection: close")
+
+        request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
         writer.write(request.encode())
         await writer.drain()
 
-        # Read status line
         status_line = await reader.readline()
         if not status_line:
             raise ConnectionError("No response from server")
@@ -214,7 +230,13 @@ async def apost(url, json_data):
 
 async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
     if args.server:
-        COMPLETION_URL = f"{args.server.rstrip('/')}/v1/chat/completions"
+        server_url = args.server.rstrip("/")
+        # Check if the server URL already contains '/v1/openai' (DeepInfra case)
+        if "/v1/openai" in server_url:
+            COMPLETION_URL = f"{server_url}/chat/completions"
+        else:
+            COMPLETION_URL = f"{server_url}/v1/chat/completions"
+        logger.debug(f"Using completion URL: {COMPLETION_URL}")
     else:
         COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
     MAX_RETRIES = args.max_page_retries
@@ -227,11 +249,19 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
     while attempt < MAX_RETRIES:
         lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
 
+        # For external servers (like DeepInfra), use the model name from args
+        # For local inference, always use 'olmocr'
+        if args.server and hasattr(args, "model"):
+            model_name = args.model
+        else:
+            model_name = "olmocr"
+
         query = await build_page_query(
             pdf_local_path,
             page_num,
             args.target_longest_image_dim,
             image_rotation=cumulative_rotation,
+            model_name=model_name,
         )
         # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
         query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
@@ -245,7 +275,12 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
         logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")
 
         try:
-            status_code, response_body = await apost(COMPLETION_URL, json_data=query)
+            # Passing API key only for external servers that need authentication
+            if args.server and hasattr(args, "api_key"):
+                api_key = args.api_key
+            else:
+                api_key = None
+            status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key)
 
             if status_code == 400:
                 raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
@@ -737,14 +772,23 @@ async def vllm_server_ready(args):
     max_attempts = 300
     delay_sec = 1
     if args.server:
-        url = f"{args.server.rstrip('/')}/v1/models"
+        # Check if the server URL already contains '/v1/openai' (DeepInfra case)
+        server_url = args.server.rstrip("/")
+        if "/v1/openai" in server_url:
+            url = f"{server_url}/models"
+        else:
+            url = f"{server_url}/v1/models"
     else:
         url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
 
     for attempt in range(1, max_attempts + 1):
         try:
+            headers = {}
+            if args.server and hasattr(args, "api_key") and args.api_key:
+                headers["Authorization"] = f"Bearer {args.api_key}"
+
             async with httpx.AsyncClient() as session:
-                response = await session.get(url)
+                response = await session.get(url, headers=headers)
 
                 if response.status_code == 200:
                     logger.info("vllm server is ready.")
@@ -1064,6 +1108,7 @@ async def main():
     parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
     parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
     parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
+    parser.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)")
 
     vllm_group = parser.add_argument_group(
         "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 600753d..1541639 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -209,7 +209,7 @@ class TestRotationCorrection:
         # Counter to track number of API calls
         call_count = 0
 
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
             nonlocal call_count
             call_count += 1
 
@@ -268,9 +268,9 @@ This is the corrected text from the document."""
         build_page_query_calls = []
         original_build_page_query = build_page_query
 
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
             build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
 
         with patch("olmocr.pipeline.apost", side_effect=mock_apost):
             with patch("olmocr.pipeline.tracker", mock_tracker):
@@ -311,7 +311,7 @@ This is the corrected text from the document."""
         # Counter to track number of API calls
         call_count = 0
 
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
             nonlocal call_count
             call_count += 1
 
@@ -376,9 +376,9 @@ Document is now correctly oriented after 180 degree rotation."""
         build_page_query_calls = []
         original_build_page_query = build_page_query
 
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
             build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
 
         with patch("olmocr.pipeline.apost", side_effect=mock_apost):
             with patch("olmocr.pipeline.tracker", mock_tracker):
@@ -420,7 +420,7 @@ Document is now correctly oriented after 180 degree rotation."""
         # Counter to track number of API calls
         call_count = 0
 
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
             nonlocal call_count
             call_count += 1
 
@@ -482,9 +482,9 @@ Document correctly oriented at 90 degrees total rotation."""
         build_page_query_calls = []
         original_build_page_query = build_page_query
 
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
             build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
 
         with patch("olmocr.pipeline.apost", side_effect=mock_apost):
             with patch("olmocr.pipeline.tracker", mock_tracker):
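The readiness probe that this diff adds to `vllm_server_ready()` (a GET to `<server>/models` with a `Bearer` token) can also be reproduced by hand to verify credentials before running the full pipeline. The snippet below is only an illustrative sketch, not part of the patch; it assumes `httpx` is installed and that `DEEPINFRA_API_KEY` is exported as shown in the README section above.

```python
# Standalone sketch of the readiness check performed by vllm_server_ready()
# when --server and --api_key point at DeepInfra. Not part of the PR itself.
import os

import httpx

SERVER = "https://api.deepinfra.com/v1/openai"  # value passed via --server in the README example
api_key = os.environ["DEEPINFRA_API_KEY"]       # exported as shown in the README section

response = httpx.get(
    f"{SERVER}/models",  # '/v1/openai' is already in the URL, so only '/models' is appended
    headers={"Authorization": f"Bearer {api_key}"},
)
response.raise_for_status()
print(response.json())  # a 200 with a model list means the endpoint and key are usable
```

A failure here usually means the `--server` URL or `--api_key` value is wrong, since the pipeline issues the same request (with the same `Authorization` header) before dispatching any page queries.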