Merge branch 'main' into jakep/new_data

Jake Poznanski 2025-10-07 17:44:54 +00:00
commit 8ef68fde88
9 changed files with 203 additions and 89 deletions

View File

@@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
+## [v0.3.8](https://github.com/allenai/olmocr/releases/tag/v0.3.8) - 2025-10-06
+## [v0.3.7](https://github.com/allenai/olmocr/releases/tag/v0.3.7) - 2025-10-06
+## [v0.3.6](https://github.com/allenai/olmocr/releases/tag/v0.3.6) - 2025-09-29
+## [v0.3.4](https://github.com/allenai/olmocr/releases/tag/v0.3.4) - 2025-08-31
## [v0.3.3](https://github.com/allenai/olmocr/releases/tag/v0.3.3) - 2025-08-15
## [v0.3.2](https://github.com/allenai/olmocr/releases/tag/v0.3.2) - 2025-08-14

View File

@@ -1,62 +1,53 @@
-ARG CUDA_VERSION=12.8.1
-FROM --platform=linux/amd64 nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
-# Needs to be repeated below the FROM, or else it's not picked up
-ARG PYTHON_VERSION=3.12
-ARG CUDA_VERSION=12.8.1
-# Set environment variable to prevent interactive prompts
-ENV DEBIAN_FRONTEND=noninteractive
-# From original VLLM dockerfile https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile
-# Install Python and other dependencies
-RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
-    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
-    && apt-get update -y \
-    && apt-get install -y ccache software-properties-common git curl sudo python3-apt \
-    && for i in 1 2 3; do \
-        add-apt-repository -y ppa:deadsnakes/ppa && break || \
-        { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
-    done \
-    && apt-get update -y \
-    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv
-# olmOCR Specific Installs - Install fonts BEFORE changing Python version
-RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections && \
-    apt-get update -y && \
-    apt-get install -y --no-install-recommends poppler-utils fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools ttf-mscorefonts-installer
-# Now update Python alternatives
-RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
-    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
-    && update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 \
-    && update-alternatives --set python /usr/bin/python${PYTHON_VERSION} \
-    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
-    && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
-    && python3 --version && python3 -m pip --version
-# Install uv for faster pip installs
-RUN --mount=type=cache,target=/root/.cache/uv python3 -m pip install uv
-# Install some helper utilities for things like the benchmark
-RUN apt-get update -y && apt-get install -y --no-install-recommends \
-    git \
-    git-lfs \
-    curl \
-    wget \
-    unzip
+FROM vllm/vllm-openai:v0.11.0
+ENV PYTHON_VERSION=3.12
+ENV CUSTOM_PY="/usr/bin/python${PYTHON_VERSION}"
+# Workaround for installing fonts, which are needed for good rendering of documents
+RUN DIST_PY=$(ls /usr/bin/python3.[0-9]* | sort -V | head -n1) && \
+    # If a python alternative scheme already exists, remember its value so we \
+    # can restore it later; otherwise, we will restore to CUSTOM_PY when we \
+    # are done. \
+    if update-alternatives --query python3 >/dev/null 2>&1; then \
+        ORIGINAL_PY=$(update-alternatives --query python3 | awk -F": " '/Value:/ {print $2}'); \
+    else \
+        ORIGINAL_PY=$CUSTOM_PY; \
+    fi && \
+    # ---- APT operations that require the distro python3 ------------------- \
+    echo "Temporarily switching python3 alternative to ${DIST_PY} so that APT scripts use the distrobuilt Python runtime." && \
+    update-alternatives --install /usr/bin/python3 python3 ${DIST_PY} 1 && \
+    update-alternatives --set python3 ${DIST_PY} && \
+    update-alternatives --install /usr/bin/python python ${DIST_PY} 1 && \
+    update-alternatives --set python ${DIST_PY} && \
+    apt-get update -y && \
+    apt-get remove -y python3-blinker || true && \
+    # Preseed the Microsoft Core Fonts EULA so the build is noninteractive \
+    echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections && \
+    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        python3-apt \
+        update-notifier-common \
+        poppler-utils \
+        fonts-crosextra-caladea \
+        fonts-crosextra-carlito \
+        gsfonts \
+        lcdf-typetools \
+        ttf-mscorefonts-installer \
+        git git-lfs curl wget unzip && \
+    # ---- Restore the original / custom Python alternative ----------------- \
+    echo "Restoring python3 alternative to ${ORIGINAL_PY}" && \
+    update-alternatives --install /usr/bin/python3 python3 ${ORIGINAL_PY} 1 && \
+    update-alternatives --set python3 ${ORIGINAL_PY} && \
+    update-alternatives --install /usr/bin/python python ${ORIGINAL_PY} 1 || true && \
+    update-alternatives --set python ${ORIGINAL_PY} || true && \
+    # Ensure pip is available for the restored Python \
+    curl -sS https://bootstrap.pypa.io/get-pip.py | ${ORIGINAL_PY}
+ENV PYTHONUNBUFFERED=1
# keep the build context clean
WORKDIR /build
COPY . /build
# Needed to resolve setuptools dependencies
ENV UV_INDEX_STRATEGY="unsafe-best-match"
-RUN uv pip install --system --no-cache ".[gpu]" --extra-index-url https://download.pytorch.org/whl/cu128
-RUN uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl
RUN uv pip install --system --no-cache ".[bench]"
RUN playwright install-deps
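The core of the rewritten RUN step is a save/switch/restore dance around `update-alternatives`: remember the current `python3` alternative, point it at the distro interpreter so APT maintainer scripts can run, then put the original interpreter back. The sketch below restates that pattern in Python with `subprocess`, purely as an illustration; the interpreter paths and fallback are assumptions, and this code is not part of the Dockerfile.

```python
import subprocess


def current_python3_alternative() -> str | None:
    """Return the current /usr/bin/python3 alternative, or None if unset."""
    try:
        out = subprocess.run(
            ["update-alternatives", "--query", "python3"],
            capture_output=True, text=True, check=True,
        ).stdout
    except subprocess.CalledProcessError:
        return None
    for line in out.splitlines():
        if line.startswith("Value:"):
            return line.split(":", 1)[1].strip()
    return None


def set_python3_alternative(path: str) -> None:
    subprocess.run(["update-alternatives", "--install", "/usr/bin/python3", "python3", path, "1"], check=True)
    subprocess.run(["update-alternatives", "--set", "python3", path], check=True)


# Save, temporarily switch to the distro interpreter for apt work, then restore.
original = current_python3_alternative() or "/usr/bin/python3.12"  # assumed fallback
set_python3_alternative("/usr/bin/python3.10")                     # assumed distro python
try:
    subprocess.run(["apt-get", "install", "-y", "ttf-mscorefonts-installer"], check=True)
finally:
    set_python3_alternative(original)
```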

View File

@@ -209,6 +209,43 @@ python -m olmocr.pipeline ./localworkspace --markdown --pdfs tests/gnarly_pdfs/*
With the addition of the `--markdown` flag, results will be stored as markdown files inside of `./localworkspace/markdown/`.
+### Using External vLLM Server
+If you have a vLLM server already running elsewhere (or any inference platform implementing the relevant subset of the OpenAI API), you can point olmOCR to use it instead of spawning a local instance:
+```bash
+# Use external vLLM server instead of local one
+python -m olmocr.pipeline ./localworkspace --server http://remote-server:8000 --markdown --pdfs tests/gnarly_pdfs/*.pdf
+```
+The served model name should be `olmocr`. An example vLLM launch command would be:
+```bash
+vllm serve allenai/olmOCR-7B-0825-FP8 --served-model-name olmocr --max-model-len 16384
+```
+#### Run olmOCR with the DeepInfra server endpoint
+Sign up at [DeepInfra](https://deepinfra.com/) and get your API key from the DeepInfra dashboard.
+Store the API key as an environment variable:
+```bash
+export DEEPINFRA_API_KEY="your-api-key-here"
+```
+```bash
+python -m olmocr.pipeline ./localworkspace \
+    --server https://api.deepinfra.com/v1/openai \
+    --api_key $DEEPINFRA_API_KEY \
+    --pages_per_group 100 \
+    --model allenai/olmOCR-7B-0825 \
+    --markdown \
+    --pdfs path/to/your/*.pdf
+```
+- `--server`: DeepInfra's OpenAI-compatible endpoint: `https://api.deepinfra.com/v1/openai`
+- `--api_key`: Your DeepInfra API key
+- `--pages_per_group`: You may want a smaller number of pages per group, since many external providers have lower concurrent request limits
+- `--model`: The model identifier on DeepInfra: `allenai/olmOCR-7B-0825`
+- Other arguments work the same as with local inference
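A quick way to sanity-check the endpoint and key before kicking off a long run is to list the provider's models directly. Below is a minimal sketch using `httpx` (already a project dependency); the `/v1/openai` URL layout mirrors what the pipeline itself probes, and `DEEPINFRA_API_KEY` is assumed to be exported as above.

```python
import os

import httpx

# DeepInfra mounts its OpenAI-compatible API under /v1/openai, so the model
# listing lives at .../v1/openai/models rather than .../v1/models.
server = "https://api.deepinfra.com/v1/openai"
headers = {"Authorization": f"Bearer {os.environ['DEEPINFRA_API_KEY']}"}

response = httpx.get(f"{server}/models", headers=headers, timeout=30)
response.raise_for_status()
model_ids = [m["id"] for m in response.json().get("data", [])]
print("allenai/olmOCR-7B-0825" in model_ids)
```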
#### Viewing Results
The `./localworkspace/` workspace folder will then have both [Dolma](https://github.com/allenai/dolma) and markdown files (if using `--markdown`).
@@ -249,6 +286,7 @@ For example:
python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4
```
### Using Docker
Pull the Docker image.
@@ -284,7 +322,7 @@ python -m olmocr.pipeline ./localworkspace --markdown --pdfs olmocr-sample.pdf
python -m olmocr.pipeline --help
usage: pipeline.py [-h] [--pdfs [PDFS ...]] [--model MODEL] [--workspace_profile WORKSPACE_PROFILE] [--pdf_profile PDF_PROFILE] [--pages_per_group PAGES_PER_GROUP] [--max_page_retries MAX_PAGE_RETRIES] [--max_page_error_rate MAX_PAGE_ERROR_RATE] [--workers WORKERS]
[--apply_filter] [--stats] [--markdown] [--target_longest_image_dim TARGET_LONGEST_IMAGE_DIM] [--target_anchor_text_len TARGET_ANCHOR_TEXT_LEN] [--guided_decoding] [--gpu-memory-utilization GPU_MEMORY_UTILIZATION] [--max_model_len MAX_MODEL_LEN]
-[--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--data-parallel-size DATA_PARALLEL_SIZE] [--port PORT] [--beaker] [--beaker_workspace BEAKER_WORKSPACE] [--beaker_cluster BEAKER_CLUSTER] [--beaker_gpus BEAKER_GPUS] [--beaker_priority BEAKER_PRIORITY]
+[--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--data-parallel-size DATA_PARALLEL_SIZE] [--port PORT] [--server SERVER] [--beaker] [--beaker_workspace BEAKER_WORKSPACE] [--beaker_cluster BEAKER_CLUSTER] [--beaker_gpus BEAKER_GPUS] [--beaker_priority BEAKER_PRIORITY]
workspace
Manager for running millions of PDFs through a batch inference pipeline
@@ -316,7 +354,7 @@ options:
Maximum amount of anchor text to use (characters), not used for new models
--guided_decoding Enable guided decoding for model YAML type outputs
-VLLM Forwarded arguments:
+VLLM arguments:
--gpu-memory-utilization GPU_MEMORY_UTILIZATION
Fraction of VRAM vLLM may pre-allocate for KV-cache (passed through to vllm serve).
--max_model_len MAX_MODEL_LEN
@@ -326,6 +364,9 @@ VLLM Forwarded arguments:
--data-parallel-size DATA_PARALLEL_SIZE, -dp DATA_PARALLEL_SIZE
Data parallel size for vLLM
--port PORT Port to use for the VLLM server
+--server SERVER       URL of external vLLM (or other compatible provider)
+                      server (e.g., http://hostname:port). If provided,
+                      skips spawning local vLLM instance
beaker/cluster execution:
--beaker Submit this job to beaker instead of running locally

View File

@@ -130,7 +130,7 @@ def normalize_text(md_content: str) -> str:
    md_content = re.sub(r"\*(.*?)\*", r"\1", md_content)
    md_content = re.sub(r"_(.*?)_", r"\1", md_content)
-    # Convert down to a consistent unicode form, so == e + accent, unicode forms
+    # Convert down to a consistent unicode form, so é == e + accent, unicode forms
    md_content = unicodedata.normalize("NFC", md_content)
    # Dictionary of characters to replace: keys are fancy characters, values are ASCII equivalents, unicode micro with greek mu comes up often enough too
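For context on that comment fix: NFC normalization is exactly what makes the precomposed "é" and the combining-accent form compare equal. A tiny self-contained illustration:

```python
import unicodedata

precomposed = "caf\u00e9"   # 'é' as a single code point
decomposed = "cafe\u0301"   # 'e' followed by a combining acute accent

print(precomposed == decomposed)                                # False
print(unicodedata.normalize("NFC", decomposed) == precomposed)  # True
```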

View File

@@ -11,6 +11,7 @@ import os
import random
import re
import shutil
+import ssl
import sys
import tempfile
import time
@@ -132,7 +133,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {
-        "model": "olmocr",
+        "model": model_name,
        "messages": [
            {
                "role": "user",
@@ -151,29 +152,44 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, that they deadlock in different strange ways
-async def apost(url, json_data):
+async def apost(url, json_data, api_key=None):
    parsed_url = urlparse(url)
    host = parsed_url.hostname
-    port = parsed_url.port or 80
+    # Default to 443 for HTTPS, 80 for HTTP
+    if parsed_url.scheme == "https":
+        port = parsed_url.port or 443
+        use_ssl = True
+    else:
+        port = parsed_url.port or 80
+        use_ssl = False
    path = parsed_url.path or "/"
    writer = None
    try:
-        reader, writer = await asyncio.open_connection(host, port)
+        if use_ssl:
+            ssl_context = ssl.create_default_context()
+            reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
+        else:
+            reader, writer = await asyncio.open_connection(host, port)
        json_payload = json.dumps(json_data)
-        request = (
-            f"POST {path} HTTP/1.1\r\n"
-            f"Host: {host}\r\n"
-            f"Content-Type: application/json\r\n"
-            f"Content-Length: {len(json_payload)}\r\n"
-            f"Connection: close\r\n\r\n"
-            f"{json_payload}"
-        )
+        headers = [
+            f"POST {path} HTTP/1.1",
+            f"Host: {host}",
+            f"Content-Type: application/json",
+            f"Content-Length: {len(json_payload)}",
+        ]
+        if api_key:
+            headers.append(f"Authorization: Bearer {api_key}")
+        headers.append("Connection: close")
+        request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
        writer.write(request.encode())
        await writer.drain()
+        # Read status line
        status_line = await reader.readline()
        if not status_line:
            raise ConnectionError("No response from server")
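For readers who want to poke at this approach outside the pipeline, here is a condensed, self-contained sketch of the same raw-HTTP idea. The helper name `post_json` and the simplified response handling (it assumes `Connection: close` and no chunked encoding) are illustrative, not the pipeline's actual `apost`.

```python
import asyncio
import json
import ssl
from urllib.parse import urlparse


async def post_json(url: str, payload: dict, api_key: str | None = None) -> tuple[int, bytes]:
    """Minimal HTTP/1.1 POST over asyncio streams, avoiding pooled HTTP clients."""
    parsed = urlparse(url)
    secure = parsed.scheme == "https"
    port = parsed.port or (443 if secure else 80)
    ssl_ctx = ssl.create_default_context() if secure else None
    reader, writer = await asyncio.open_connection(parsed.hostname, port, ssl=ssl_ctx)

    body = json.dumps(payload).encode()
    headers = [
        f"POST {parsed.path or '/'} HTTP/1.1",
        f"Host: {parsed.hostname}",
        "Content-Type: application/json",
        f"Content-Length: {len(body)}",
    ]
    if api_key:
        headers.append(f"Authorization: Bearer {api_key}")
    headers.append("Connection: close")
    writer.write(("\r\n".join(headers) + "\r\n\r\n").encode() + body)
    await writer.drain()

    status_line = await reader.readline()        # e.g. b"HTTP/1.1 200 OK\r\n"
    status_code = int(status_line.split()[1])
    while (await reader.readline()).strip():     # skip response headers
        pass
    response_body = await reader.read()          # Connection: close -> read to EOF
    writer.close()
    await writer.wait_closed()
    return status_code, response_body
```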
@@ -213,7 +229,16 @@ async def apost(url, json_data):
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
-    COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
+    if args.server:
+        server_url = args.server.rstrip("/")
+        # Check if the server URL already contains '/v1/openai' (DeepInfra case)
+        if "/v1/openai" in server_url:
+            COMPLETION_URL = f"{server_url}/chat/completions"
+        else:
+            COMPLETION_URL = f"{server_url}/v1/chat/completions"
+        logger.debug(f"Using completion URL: {COMPLETION_URL}")
+    else:
+        COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions"
    MAX_RETRIES = args.max_page_retries
    MODEL_MAX_CONTEXT = 16384
    TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0]
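The same endpoint-derivation rule shows up again in `vllm_server_ready` below. As a standalone illustration of that rule (the helper name is hypothetical, and the default port is taken from the `--port` default shown later):

```python
def completions_url(server: str | None, local_port: int = 30024) -> str:
    """Derive the chat-completions endpoint the pipeline will POST to."""
    if not server:
        return f"http://localhost:{local_port}/v1/chat/completions"
    server = server.rstrip("/")
    # Providers like DeepInfra already mount the OpenAI API under /v1/openai,
    # so only "/chat/completions" needs to be appended in that case.
    if "/v1/openai" in server:
        return f"{server}/chat/completions"
    return f"{server}/v1/chat/completions"


assert completions_url(None) == "http://localhost:30024/v1/chat/completions"
assert completions_url("http://remote-server:8000/") == "http://remote-server:8000/v1/chat/completions"
assert completions_url("https://api.deepinfra.com/v1/openai") == "https://api.deepinfra.com/v1/openai/chat/completions"
```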
@@ -224,11 +249,19 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
    while attempt < MAX_RETRIES:
        lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1)
+        # For external servers (like DeepInfra), use the model name from args
+        # For local inference, always use 'olmocr'
+        if args.server and hasattr(args, "model"):
+            model_name = args.model
+        else:
+            model_name = "olmocr"
        query = await build_page_query(
            pdf_local_path,
            page_num,
            args.target_longest_image_dim,
            image_rotation=cumulative_rotation,
+            model_name=model_name,
        )
        # Change temperature as number of attempts increases to overcome repetition issues at expense of quality
        query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt]
@@ -242,10 +275,17 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
        logger.debug(f"Built page query for {pdf_orig_path}-{page_num}")
        try:
-            status_code, response_body = await apost(COMPLETION_URL, json_data=query)
+            # Passing API key only for external servers that need authentication
+            if args.server and hasattr(args, "api_key"):
+                api_key = args.api_key
+            else:
+                api_key = None
+            status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key)
            if status_code == 400:
                raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
+            elif status_code == 429:
+                raise ConnectionError(f"Too many requests, doing exponential backoff")
            elif status_code == 500:
                raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
            elif status_code != 200:
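The new 429 branch leans on the retry loop already present in `process_page`, which (per the surrounding code, not shown in this hunk) sleeps before retrying when a `ConnectionError` is raised. A generic sketch of that retry-with-exponential-backoff pattern, with made-up delay constants rather than the pipeline's actual values:

```python
import asyncio
import random


async def with_backoff(request_fn, max_retries: int = 8, base_delay: float = 1.0):
    """Retry an async request, backing off exponentially on ConnectionError (e.g. HTTP 429)."""
    for attempt in range(max_retries):
        try:
            return await request_fn()
        except ConnectionError:
            # Exponential backoff with jitter: ~1s, 2s, 4s, ... capped at 60s.
            delay = min(base_delay * (2 ** attempt), 60.0) * (0.5 + random.random())
            await asyncio.sleep(delay)
    raise RuntimeError("Exceeded retry budget")
```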
@@ -596,6 +636,8 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non
        str(args.tensor_parallel_size),
        "--data-parallel-size",
        str(args.data_parallel_size),
+        "--limit-mm-per-prompt",
+        '{"video": 0}',  # Disabling video encoder saves RAM that you can put towards the KV cache, thanks @charitarthchugh
    ]
    if args.gpu_memory_utilization is not None:
@@ -730,15 +772,27 @@ async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=Non
        sys.exit(1)
-async def vllm_server_ready():
+async def vllm_server_ready(args):
    max_attempts = 300
    delay_sec = 1
-    url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
+    if args.server:
+        # Check if the server URL already contains '/v1/openai' (DeepInfra case)
+        server_url = args.server.rstrip("/")
+        if "/v1/openai" in server_url:
+            url = f"{server_url}/models"
+        else:
+            url = f"{server_url}/v1/models"
+    else:
+        url = f"http://localhost:{BASE_SERVER_PORT}/v1/models"
    for attempt in range(1, max_attempts + 1):
        try:
+            headers = {}
+            if args.server and hasattr(args, "api_key") and args.api_key:
+                headers["Authorization"] = f"Bearer {args.api_key}"
            async with httpx.AsyncClient() as session:
-                response = await session.get(url)
+                response = await session.get(url, headers=headers)
                if response.status_code == 200:
                    logger.info("vllm server is ready.")
@@ -1058,6 +1112,7 @@ async def main():
    parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
    parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
    parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
+    parser.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)")
    vllm_group = parser.add_argument_group(
        "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
@@ -1069,6 +1124,11 @@ async def main():
    vllm_group.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM")
    vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
    vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
+    vllm_group.add_argument(
+        "--server",
+        type=str,
+        help="URL of external vLLM (or other compatible provider) server (e.g., http://hostname:port). If provided, skips spawning local vLLM instance",
+    )
    # Beaker/job running stuff
    beaker_group = parser.add_argument_group("beaker/cluster execution")
@@ -1207,12 +1267,17 @@ async def main():
    # If you get this far, then you are doing inference and need a GPU
    # check_sglang_version()
-    check_torch_gpu_available()
+    if not args.server:
+        check_torch_gpu_available()
    logger.info(f"Starting pipeline with PID {os.getpid()}")
    # Download the model before you do anything else
-    model_name_or_path = await download_model(args.model)
+    if args.server:
+        logger.info(f"Using external server at {args.server}")
+        model_name_or_path = None
+    else:
+        model_name_or_path = await download_model(args.model)
    # Initialize the work queue
    qsize = await work_queue.initialize_queue()
@@ -1226,9 +1291,12 @@ async def main():
    # As soon as one worker is no longer saturating the gpu, the next one can start sending requests
    semaphore = asyncio.Semaphore(1)
-    vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
+    # Start local vLLM instance if not using external one
+    vllm_server = None
+    if not args.server:
+        vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args))
-    await vllm_server_ready()
+    await vllm_server_ready(args)
    metrics_task = asyncio.create_task(metrics_reporter(work_queue))
@@ -1241,11 +1309,16 @@ async def main():
    # Wait for all worker tasks to finish
    await asyncio.gather(*worker_tasks)
-    vllm_server.cancel()
+    # Cancel vLLM server if it was started
+    if vllm_server is not None:
+        vllm_server.cancel()
    metrics_task.cancel()
    # Wait for cancelled tasks to complete
-    await asyncio.gather(vllm_server, metrics_task, return_exceptions=True)
+    tasks_to_wait = [metrics_task]
+    if vllm_server is not None:
+        tasks_to_wait.append(vllm_server)
+    await asyncio.gather(*tasks_to_wait, return_exceptions=True)
    # Output final metrics summary
    metrics_summary = metrics.get_metrics_summary()
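The shutdown change is an instance of a general asyncio pattern: a background task that may or may not have been started has to be cancelled and awaited conditionally. A minimal standalone sketch of that pattern (the names and sleeps are illustrative, not the pipeline's code):

```python
import asyncio


async def serve_forever():
    # Stand-in for a vllm_server_host / metrics_reporter style loop.
    while True:
        await asyncio.sleep(1)


async def run(start_server: bool):
    server_task = asyncio.create_task(serve_forever()) if start_server else None
    metrics_task = asyncio.create_task(serve_forever())

    await asyncio.sleep(0.1)  # ... real work would happen here ...

    # Cancel whatever was actually started, then await so cancellation completes.
    tasks = [metrics_task] + ([server_task] if server_task else [])
    for t in tasks:
        t.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(run(start_server=False))
```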

View File

@@ -2,7 +2,7 @@ _MAJOR = "0"
_MINOR = "3"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
-_PATCH = "3"
+_PATCH = "8"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""

View File

@@ -37,7 +37,7 @@ dependencies = [
    "boto3",
    "httpx",
    "torch>=2.7.0",
-    "transformers==4.53.2",
+    "transformers==4.55.2",
    "img2pdf",
    "beaker-py",
]
@@ -51,7 +51,7 @@ Changelog = "https://github.com/allenai/olmocr/blob/main/CHANGELOG.md"
[project.optional-dependencies]
gpu = [
-    "vllm==0.10.0"
+    "vllm==0.11.0"
]
dev = [

View File

@@ -68,7 +68,7 @@ read -p "Creating new release for $TAG. Do you want to continue? [Y/n] " prompt
if [[ $prompt == "y" || $prompt == "Y" || $prompt == "yes" || $prompt == "Yes" ]]; then
    python scripts/prepare_changelog.py
-    git add -A
+    git add CHANGELOG.md
    git commit -m "Bump version to $TAG for release" || true && git push
    echo "Creating new git tag $TAG"
    git tag "$TAG" -m "$TAG"

View File

@@ -192,6 +192,7 @@ class MockArgs:
    max_page_retries: int = 8
    target_longest_image_dim: int = 1288
    guided_decoding: bool = False
+    server: str | None = None
class TestRotationCorrection:
@@ -208,7 +209,7 @@ class TestRotationCorrection:
        # Counter to track number of API calls
        call_count = 0
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
            nonlocal call_count
            call_count += 1
@@ -267,9 +268,9 @@ This is the corrected text from the document."""
        build_page_query_calls = []
        original_build_page_query = build_page_query
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
            build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
        with patch("olmocr.pipeline.apost", side_effect=mock_apost):
            with patch("olmocr.pipeline.tracker", mock_tracker):
@@ -310,7 +311,7 @@ This is the corrected text from the document."""
        # Counter to track number of API calls
        call_count = 0
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
            nonlocal call_count
            call_count += 1
@@ -375,9 +376,9 @@ Document is now correctly oriented after 180 degree rotation."""
        build_page_query_calls = []
        original_build_page_query = build_page_query
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
            build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
        with patch("olmocr.pipeline.apost", side_effect=mock_apost):
            with patch("olmocr.pipeline.tracker", mock_tracker):
@@ -419,7 +420,7 @@ Document is now correctly oriented after 180 degree rotation."""
        # Counter to track number of API calls
        call_count = 0
-        async def mock_apost(url, json_data):
+        async def mock_apost(url, json_data, api_key=None):
            nonlocal call_count
            call_count += 1
@@ -481,9 +482,9 @@ Document correctly oriented at 90 degrees total rotation."""
        build_page_query_calls = []
        original_build_page_query = build_page_query
-        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0):
+        async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"):
            build_page_query_calls.append(image_rotation)
-            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation)
+            return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name)
        with patch("olmocr.pipeline.apost", side_effect=mock_apost):
            with patch("olmocr.pipeline.tracker", mock_tracker):
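Because `apost` now takes an optional `api_key` and `build_page_query` an optional `model_name`, any test double has to accept those keywords with defaults, as the updated mocks above do. A minimal sketch of that idea in isolation (the fake response body and the test itself are illustrative, not the repo's actual tests):

```python
import asyncio
from unittest.mock import patch

import olmocr.pipeline as pipeline  # assumes the olmocr package is importable


async def fake_apost(url, json_data, api_key=None):
    # Accepting api_key keeps the double signature-compatible with the real apost;
    # without it, the pipeline's apost(..., api_key=...) call would raise TypeError.
    return 200, b'{"choices": []}'


def test_fake_apost_accepts_api_key():
    with patch("olmocr.pipeline.apost", side_effect=fake_apost):
        status, body = asyncio.run(pipeline.apost("http://localhost/v1/chat/completions", {"model": "olmocr"}, api_key="secret"))
        assert status == 200 and body == b'{"choices": []}'
```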