Merge pull request #344 from allenai/amanr/deepinfra

DeepInfra Support
This commit is contained in:
Jake Poznanski 2025-09-29 10:03:29 -07:00 committed by GitHub
commit f0caa188ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 92 additions and 27 deletions

View File

@ -249,6 +249,26 @@ For example:
```bash ```bash
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4 python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4
``` ```
### Using DeepInfra
Sign up at [DeepInfra](https://deepinfra.com/) and get your API key from the DeepInfra dashboard.
Store the API key as an environment variable.
```bash
export DEEPINFRA_API_KEY="your-api-key-here"
```
#### Run olmOCR with the DeepInfra server endpoint:
```bash
python -m olmocr.pipeline ./localworkspace \
--server https://api.deepinfra.com/v1/openai \
--api_key $DEEPINFRA_API_KEY \
--model allenai/olmOCR-7B-0725-FP8 \
--markdown \
--pdfs path/to/your/*.pdf
```
- `--server`: DeepInfra's OpenAI-compatible endpoint: `https://api.deepinfra.com/v1/openai`
- `--api_key`: Your DeepInfra API key
- `--model`: The model identifier on DeepInfra: `allenai/olmOCR-7B-0725-FP8`
- Other arguments work the same as with local inference
### Using Docker ### Using Docker

View File

@ -11,6 +11,7 @@ import os
import random import random
import re import re
import shutil import shutil
import ssl
import sys import sys
import tempfile import tempfile
import time import time
@ -104,7 +105,7 @@ class PageResult:
is_fallback: bool is_fallback: bool
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict: async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0, model_name: str = "olmocr") -> dict:
MAX_TOKENS = 4500 MAX_TOKENS = 4500
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query" assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
@ -132,7 +133,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return { return {
"model": "olmocr", "model": model_name,
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -151,29 +152,44 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
# It feels strange perhaps, but httpx and aiohttp are very complex beasts # It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed # Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, that they deadlock in different strange ways # that at the scale of 100M+ requests, that they deadlock in different strange ways
async def apost(url, json_data): async def apost(url, json_data, api_key=None):
parsed_url = urlparse(url) parsed_url = urlparse(url)
host = parsed_url.hostname host = parsed_url.hostname
port = parsed_url.port or 80 # Default to 443 for HTTPS, 80 for HTTP
if parsed_url.scheme == "https":
port = parsed_url.port or 443
use_ssl = True
else:
port = parsed_url.port or 80
use_ssl = False
path = parsed_url.path or "/" path = parsed_url.path or "/"
writer = None writer = None
try: try:
reader, writer = await asyncio.open_connection(host, port) if use_ssl:
ssl_context = ssl.create_default_context()
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
else:
reader, writer = await asyncio.open_connection(host, port)
json_payload = json.dumps(json_data) json_payload = json.dumps(json_data)
request = (
f"POST {path} HTTP/1.1\r\n" headers = [
f"Host: {host}\r\n" f"POST {path} HTTP/1.1",
f"Content-Type: application/json\r\n" f"Host: {host}",
f"Content-Length: {len(json_payload)}\r\n" f"Content-Type: application/json",
f"Connection: close\r\n\r\n" f"Content-Length: {len(json_payload)}",
f"{json_payload}" ]
)
if api_key:
headers.append(f"Authorization: Bearer {api_key}")
headers.append("Connection: close")
request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
writer.write(request.encode()) writer.write(request.encode())
await writer.drain() await writer.drain()
# Read status line
status_line = await reader.readline() status_line = await reader.readline()
if not status_line: if not status_line:
raise ConnectionError("No response from server") raise ConnectionError("No response from server")
@ -214,7 +230,13 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult: async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
if args.server: if args.server:
COMPLETION_URL = f"{args.server.rstrip('/')}/v1/chat/completions" server_url = args.server.rstrip("/")
# Check if the server URL already contains '/v1/openai' (DeepInfra case)
if "/v1/openai" in server_url:
COMPLETION_URL = f"{server_url}/chat/completions"
else:
COMPLETION_URL = f"{server_url}/v1/chat/completions"
logger.debug(f"Using completion URL: {COMPLETION_URL}")
else: else:
COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions" COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries MAX_RETRIES = args.max_page_retries
@ -227,11 +249,19 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
while attempt < MAX_RETRIES: while attempt < MAX_RETRIES:
lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1) lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
# For external servers (like DeepInfra), use the model name from args
# For local inference, always use 'olmocr'
if args.server and hasattr(args, "model"):
model_name = args.model
else:
model_name = "olmocr"
query = await build_page_query( query = await build_page_query(
pdf_local_path, pdf_local_path,
page_num, page_num,
args.target_longest_image_dim, args.target_longest_image_dim,
image_rotation=cumulative_rotation, image_rotation=cumulative_rotation,
model_name=model_name,
) )
# Change temperature as number of attempts increases to overcome repetition issues at expense of quality # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt] query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
@ -245,7 +275,12 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
logger.debug(f"Built page query for {pdf_orig_path}-{page_num}") logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")
try: try:
status_code, response_body = await apost(COMPLETION_URL, json_data=query) # Passing API key only for external servers that need authentication
if args.server and hasattr(args, "api_key"):
api_key = args.api_key
else:
api_key = None
status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key)
if status_code == 400: if status_code == 400:
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response") raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
@ -737,14 +772,23 @@ async def vllm_server_ready(args):
max_attempts = 300 max_attempts = 300
delay_sec = 1 delay_sec = 1
if args.server: if args.server:
url = f"{args.server.rstrip('/')}/v1/models" # Check if the server URL already contains '/v1/openai' (DeepInfra case)
server_url = args.server.rstrip("/")
if "/v1/openai" in server_url:
url = f"{server_url}/models"
else:
url = f"{server_url}/v1/models"
else: else:
url = f"http://localhost:{BASE_SERVER_PORT}/v1/models" url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
for attempt in range(1, max_attempts + 1): for attempt in range(1, max_attempts + 1):
try: try:
headers = {}
if args.server and hasattr(args, "api_key") and args.api_key:
headers["Authorization"] = f"Bearer {args.api_key}"
async with httpx.AsyncClient() as session: async with httpx.AsyncClient() as session:
response = await session.get(url) response = await session.get(url, headers=headers)
if response.status_code == 200: if response.status_code == 200:
logger.info("vllm server is ready.") logger.info("vllm server is ready.")
@ -1064,6 +1108,7 @@ async def main():
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288) parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1) parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs") parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
parser.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)")
vllm_group = parser.add_argument_group( vllm_group = parser.add_argument_group(
"VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM." "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."

View File

@ -209,7 +209,7 @@ class TestRotationCorrection:
# Counter to track number of API calls # Counter to track number of API calls
call_count = 0 call_count = 0
async def mock_apost(url, json_data): async def mock_apost(url, json_data, api_key=None):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@ -268,9 +268,9 @@ This is the corrected text from the document."""
build_page_query_calls = [] build_page_query_calls = []
original_build_page_query = build_page_query original_build_page_query = build_page_query
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
build_page_query_calls.append(image_rotation) build_page_query_calls.append(image_rotation)
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
with patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.apost", side_effect=mock_apost):
with patch("olmocr.pipeline.tracker", mock_tracker): with patch("olmocr.pipeline.tracker", mock_tracker):
@ -311,7 +311,7 @@ This is the corrected text from the document."""
# Counter to track number of API calls # Counter to track number of API calls
call_count = 0 call_count = 0
async def mock_apost(url, json_data): async def mock_apost(url, json_data, api_key=None):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@ -376,9 +376,9 @@ Document is now correctly oriented after 180 degree rotation."""
build_page_query_calls = [] build_page_query_calls = []
original_build_page_query = build_page_query original_build_page_query = build_page_query
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
build_page_query_calls.append(image_rotation) build_page_query_calls.append(image_rotation)
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
with patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.apost", side_effect=mock_apost):
with patch("olmocr.pipeline.tracker", mock_tracker): with patch("olmocr.pipeline.tracker", mock_tracker):
@ -420,7 +420,7 @@ Document is now correctly oriented after 180 degree rotation."""
# Counter to track number of API calls # Counter to track number of API calls
call_count = 0 call_count = 0
async def mock_apost(url, json_data): async def mock_apost(url, json_data, api_key=None):
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@ -482,9 +482,9 @@ Document correctly oriented at 90 degrees total rotation."""
build_page_query_calls = [] build_page_query_calls = []
original_build_page_query = build_page_query original_build_page_query = build_page_query
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
build_page_query_calls.append(image_rotation) build_page_query_calls.append(image_rotation)
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
with patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.apost", side_effect=mock_apost):
with patch("olmocr.pipeline.tracker", mock_tracker): with patch("olmocr.pipeline.tracker", mock_tracker):