mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-05 19:51:28 +00:00
commit
f0caa188ab
20
README.md
20
README.md
@ -249,6 +249,26 @@ For example:
|
|||||||
```bash
|
```bash
|
||||||
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4
|
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4
|
||||||
```
|
```
|
||||||
|
### Using DeepInfra
|
||||||
|
Signup at [DeepInfra](https://deepinfra.com/) and get your API key from the DeepInfra dashboard.
|
||||||
|
Store the API key as an environment variable.
|
||||||
|
```bash
|
||||||
|
export DEEPINFRA_API_KEY="your-api-key-here"
|
||||||
|
```
|
||||||
|
#### Run olmOCR with the DeepInfra server endpoint:
|
||||||
|
```bash
|
||||||
|
python -m olmocr.pipeline ./localworkspace \
|
||||||
|
--server https://api.deepinfra.com/v1/openai \
|
||||||
|
--api_key $DEEPINFRA_API_KEY \
|
||||||
|
--model allenai/olmOCR-7B-0725-FP8 \
|
||||||
|
--markdown \
|
||||||
|
--pdfs path/to/your/*.pdf
|
||||||
|
```
|
||||||
|
- `--server`: DeepInfra's OpenAI-compatible endpoint: `https://api.deepinfra.com/v1/openai`
|
||||||
|
- `--api_key`: Your DeepInfra API key
|
||||||
|
- `--model`: The model identifier on DeepInfra: `allenai/olmOCR-7B-0725-FP8`
|
||||||
|
- Other arguments work the same as with local inference
|
||||||
|
|
||||||
|
|
||||||
### Using Docker
|
### Using Docker
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import ssl
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
@ -104,7 +105,7 @@ class PageResult:
|
|||||||
is_fallback: bool
|
is_fallback: bool
|
||||||
|
|
||||||
|
|
||||||
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0) -> dict:
|
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, image_rotation: int = 0, model_name: str = "olmocr") -> dict:
|
||||||
MAX_TOKENS = 4500
|
MAX_TOKENS = 4500
|
||||||
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
|
||||||
|
|
||||||
@ -132,7 +133,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
|
|||||||
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": "olmocr",
|
"model": model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@ -151,29 +152,44 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
|
|||||||
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
|
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
|
||||||
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
|
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
|
||||||
# that at the scale of 100M+ requests, that they deadlock in different strange ways
|
# that at the scale of 100M+ requests, that they deadlock in different strange ways
|
||||||
async def apost(url, json_data):
|
async def apost(url, json_data, api_key=None):
|
||||||
parsed_url = urlparse(url)
|
parsed_url = urlparse(url)
|
||||||
host = parsed_url.hostname
|
host = parsed_url.hostname
|
||||||
port = parsed_url.port or 80
|
# Default to 443 for HTTPS, 80 for HTTP
|
||||||
|
if parsed_url.scheme == "https":
|
||||||
|
port = parsed_url.port or 443
|
||||||
|
use_ssl = True
|
||||||
|
else:
|
||||||
|
port = parsed_url.port or 80
|
||||||
|
use_ssl = False
|
||||||
path = parsed_url.path or "/"
|
path = parsed_url.path or "/"
|
||||||
|
|
||||||
writer = None
|
writer = None
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.open_connection(host, port)
|
if use_ssl:
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
|
||||||
|
else:
|
||||||
|
reader, writer = await asyncio.open_connection(host, port)
|
||||||
|
|
||||||
json_payload = json.dumps(json_data)
|
json_payload = json.dumps(json_data)
|
||||||
request = (
|
|
||||||
f"POST {path} HTTP/1.1\r\n"
|
headers = [
|
||||||
f"Host: {host}\r\n"
|
f"POST {path} HTTP/1.1",
|
||||||
f"Content-Type: application/json\r\n"
|
f"Host: {host}",
|
||||||
f"Content-Length: {len(json_payload)}\r\n"
|
f"Content-Type: application/json",
|
||||||
f"Connection: close\r\n\r\n"
|
f"Content-Length: {len(json_payload)}",
|
||||||
f"{json_payload}"
|
]
|
||||||
)
|
|
||||||
|
if api_key:
|
||||||
|
headers.append(f"Authorization: Bearer {api_key}")
|
||||||
|
|
||||||
|
headers.append("Connection: close")
|
||||||
|
|
||||||
|
request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
|
||||||
writer.write(request.encode())
|
writer.write(request.encode())
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
|
|
||||||
# Read status line
|
|
||||||
status_line = await reader.readline()
|
status_line = await reader.readline()
|
||||||
if not status_line:
|
if not status_line:
|
||||||
raise ConnectionError("No response from server")
|
raise ConnectionError("No response from server")
|
||||||
@ -214,7 +230,13 @@ async def apost(url, json_data):
|
|||||||
|
|
||||||
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
|
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
|
||||||
if args.server:
|
if args.server:
|
||||||
COMPLETION_URL = f"{args.server.rstrip('/')}/v1/chat/completions"
|
server_url = args.server.rstrip("/")
|
||||||
|
# Check if the server URL already contains '/v1/openai' (DeepInfra case)
|
||||||
|
if "/v1/openai" in server_url:
|
||||||
|
COMPLETION_URL = f"{server_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
COMPLETION_URL = f"{server_url}/v1/chat/completions"
|
||||||
|
logger.debug(f"Using completion URL: {COMPLETION_URL}")
|
||||||
else:
|
else:
|
||||||
COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
|
COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
|
||||||
MAX_RETRIES = args.max_page_retries
|
MAX_RETRIES = args.max_page_retries
|
||||||
@ -227,11 +249,19 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
|
|||||||
|
|
||||||
while attempt < MAX_RETRIES:
|
while attempt < MAX_RETRIES:
|
||||||
lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
|
lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
|
||||||
|
# For external servers (like DeepInfra), use the model name from args
|
||||||
|
# For local inference, always use 'olmocr'
|
||||||
|
if args.server and hasattr(args, "model"):
|
||||||
|
model_name = args.model
|
||||||
|
else:
|
||||||
|
model_name = "olmocr"
|
||||||
|
|
||||||
query = await build_page_query(
|
query = await build_page_query(
|
||||||
pdf_local_path,
|
pdf_local_path,
|
||||||
page_num,
|
page_num,
|
||||||
args.target_longest_image_dim,
|
args.target_longest_image_dim,
|
||||||
image_rotation=cumulative_rotation,
|
image_rotation=cumulative_rotation,
|
||||||
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
# Change temperature as number of attempts increases to overcome repetition issues at expense of quality
|
# Change temperature as number of attempts increases to overcome repetition issues at expense of quality
|
||||||
query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
|
query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
|
||||||
@ -245,7 +275,12 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
|
|||||||
logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")
|
logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
status_code, response_body = await apost(COMPLETION_URL, json_data=query)
|
# Passing API key only for external servers that need authentication
|
||||||
|
if args.server and hasattr(args, "api_key"):
|
||||||
|
api_key = args.api_key
|
||||||
|
else:
|
||||||
|
api_key = None
|
||||||
|
status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key)
|
||||||
|
|
||||||
if status_code == 400:
|
if status_code == 400:
|
||||||
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
|
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
|
||||||
@ -737,14 +772,23 @@ async def vllm_server_ready(args):
|
|||||||
max_attempts = 300
|
max_attempts = 300
|
||||||
delay_sec = 1
|
delay_sec = 1
|
||||||
if args.server:
|
if args.server:
|
||||||
url = f"{args.server.rstrip('/')}/v1/models"
|
# Check if the server URL already contains '/v1/openai' (DeepInfra case)
|
||||||
|
server_url = args.server.rstrip("/")
|
||||||
|
if "/v1/openai" in server_url:
|
||||||
|
url = f"{server_url}/models"
|
||||||
|
else:
|
||||||
|
url = f"{server_url}/v1/models"
|
||||||
else:
|
else:
|
||||||
url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
|
url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
|
||||||
|
|
||||||
for attempt in range(1, max_attempts + 1):
|
for attempt in range(1, max_attempts + 1):
|
||||||
try:
|
try:
|
||||||
|
headers = {}
|
||||||
|
if args.server and hasattr(args, "api_key") and args.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {args.api_key}"
|
||||||
|
|
||||||
async with httpx.AsyncClient() as session:
|
async with httpx.AsyncClient() as session:
|
||||||
response = await session.get(url)
|
response = await session.get(url, headers=headers)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
logger.info("vllm server is ready.")
|
logger.info("vllm server is ready.")
|
||||||
@ -1064,6 +1108,7 @@ async def main():
|
|||||||
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
|
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
|
||||||
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
|
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
|
||||||
parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
|
parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
|
||||||
|
parser.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)")
|
||||||
|
|
||||||
vllm_group = parser.add_argument_group(
|
vllm_group = parser.add_argument_group(
|
||||||
"VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
|
"VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
|
||||||
|
|||||||
@ -209,7 +209,7 @@ class TestRotationCorrection:
|
|||||||
# Counter to track number of API calls
|
# Counter to track number of API calls
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_apost(url, json_data):
|
async def mock_apost(url, json_data, api_key=None):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|
||||||
@ -268,9 +268,9 @@ This is the corrected text from the document."""
|
|||||||
build_page_query_calls = []
|
build_page_query_calls = []
|
||||||
original_build_page_query = build_page_query
|
original_build_page_query = build_page_query
|
||||||
|
|
||||||
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
|
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
|
||||||
build_page_query_calls.append(image_rotation)
|
build_page_query_calls.append(image_rotation)
|
||||||
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
|
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
|
||||||
|
|
||||||
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
||||||
with patch("olmocr.pipeline.tracker", mock_tracker):
|
with patch("olmocr.pipeline.tracker", mock_tracker):
|
||||||
@ -311,7 +311,7 @@ This is the corrected text from the document."""
|
|||||||
# Counter to track number of API calls
|
# Counter to track number of API calls
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_apost(url, json_data):
|
async def mock_apost(url, json_data, api_key=None):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|
||||||
@ -376,9 +376,9 @@ Document is now correctly oriented after 180 degree rotation."""
|
|||||||
build_page_query_calls = []
|
build_page_query_calls = []
|
||||||
original_build_page_query = build_page_query
|
original_build_page_query = build_page_query
|
||||||
|
|
||||||
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
|
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
|
||||||
build_page_query_calls.append(image_rotation)
|
build_page_query_calls.append(image_rotation)
|
||||||
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
|
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
|
||||||
|
|
||||||
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
||||||
with patch("olmocr.pipeline.tracker", mock_tracker):
|
with patch("olmocr.pipeline.tracker", mock_tracker):
|
||||||
@ -420,7 +420,7 @@ Document is now correctly oriented after 180 degree rotation."""
|
|||||||
# Counter to track number of API calls
|
# Counter to track number of API calls
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def mock_apost(url, json_data):
|
async def mock_apost(url, json_data, api_key=None):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|
||||||
@ -482,9 +482,9 @@ Document correctly oriented at 90 degrees total rotation."""
|
|||||||
build_page_query_calls = []
|
build_page_query_calls = []
|
||||||
original_build_page_query = build_page_query
|
original_build_page_query = build_page_query
|
||||||
|
|
||||||
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
|
async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
|
||||||
build_page_query_calls.append(image_rotation)
|
build_page_query_calls.append(image_rotation)
|
||||||
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
|
return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
|
||||||
|
|
||||||
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
with patch("olmocr.pipeline.apost", side_effect=mock_apost):
|
||||||
with patch("olmocr.pipeline.tracker", mock_tracker):
|
with patch("olmocr.pipeline.tracker", mock_tracker):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user