diff --git a/CHANGELOG.md b/CHANGELOG.md index 86f22d1..a1d37d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +## [v0.3.8](https://github.com/allenai/olmocr/releases/tag/v0.3.8) - 2025-10-06 + +## [v0.3.7](https://github.com/allenai/olmocr/releases/tag/v0.3.7) - 2025-10-06 + +## [v0.3.6](https://github.com/allenai/olmocr/releases/tag/v0.3.6) - 2025-09-29 + +## [v0.3.4](https://github.com/allenai/olmocr/releases/tag/v0.3.4) - 2025-08-31 + ## [v0.3.3](https://github.com/allenai/olmocr/releases/tag/v0.3.3) - 2025-08-15 ## [v0.3.2](https://github.com/allenai/olmocr/releases/tag/v0.3.2) - 2025-08-14 diff --git a/Dockerfile b/Dockerfile index 2ac06b0..7db0fea 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,62 +1,53 @@ -ARG CUDA_VERSION=12.8.1 -FROM --platform=linux/amd64 nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 +FROM vllm/vllm-openai:v0.11.0 -# Needs to be repeated below the FROM, or else it's not picked up -ARG PYTHON_VERSION=3.12 -ARG CUDA_VERSION=12.8.1 +ENV PYTHON_VERSION=3.12 +ENV CUSTOM_PY="/usr/bin/python${PYTHON_VERSION}" -# Set environment variable to prevent interactive prompts -ENV DEBIAN_FRONTEND=noninteractive - -# From original VLLM dockerfile https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile -# Install Python and other dependencies -RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ - && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ - && apt-get update -y \ - && apt-get install -y ccache software-properties-common git curl sudo python3-apt \ - && for i in 1 2 3; do \ - add-apt-repository -y ppa:deadsnakes/ppa && break || \ - { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ - done \ - && apt-get update -y \ - && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv - -# olmOCR Specific Installs - Install fonts BEFORE changing Python version -RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections && \ +# Workaround for installing fonts, which are needed for good rendering of documents +RUN DIST_PY=$(ls /usr/bin/python3.[0-9]* | sort -V | head -n1) && \ + # If a python alternative scheme already exists, remember its value so we \ + # can restore it later; otherwise, we will restore to CUSTOM_PY when we \ + # are done. \ + if update-alternatives --query python3 >/dev/null 2>&1; then \ + ORIGINAL_PY=$(update-alternatives --query python3 | awk -F": " '/Value:/ {print $2}'); \ + else \ + ORIGINAL_PY=$CUSTOM_PY; \ + fi && \ + # ---- APT operations that require the distro python3 ------------------- \ + echo "Temporarily switching python3 alternative to ${DIST_PY} so that APT scripts use the distro‑built Python runtime." 
&& \ + update-alternatives --install /usr/bin/python3 python3 ${DIST_PY} 1 && \ + update-alternatives --set python3 ${DIST_PY} && \ + update-alternatives --install /usr/bin/python python ${DIST_PY} 1 && \ + update-alternatives --set python ${DIST_PY} && \ apt-get update -y && \ - apt-get install -y --no-install-recommends poppler-utils fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools ttf-mscorefonts-installer - -# Now update Python alternatives -RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ - && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ - && update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 \ - && update-alternatives --set python /usr/bin/python${PYTHON_VERSION} \ - && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ - && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ - && python3 --version && python3 -m pip --version - -# Install uv for faster pip installs -RUN --mount=type=cache,target=/root/.cache/uv python3 -m pip install uv - -# Install some helper utilities for things like the benchmark -RUN apt-get update -y && apt-get install -y --no-install-recommends \ - git \ - git-lfs \ - curl \ - wget \ - unzip - -ENV PYTHONUNBUFFERED=1 + apt-get remove -y python3-blinker || true && \ + # Pre‑seed the Microsoft Core Fonts EULA so the build is non‑interactive \ + echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections && \ + DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + python3-apt \ + update-notifier-common \ + poppler-utils \ + fonts-crosextra-caladea \ + fonts-crosextra-carlito \ + gsfonts \ + lcdf-typetools \ + ttf-mscorefonts-installer \ + git git-lfs curl wget unzip && \ + # ---- Restore the original / custom Python alternative ----------------- \ + echo "Restoring python3 alternative to ${ORIGINAL_PY}" && \ + update-alternatives --install /usr/bin/python3 python3 ${ORIGINAL_PY} 1 && \ + update-alternatives --set python3 ${ORIGINAL_PY} && \ + update-alternatives --install /usr/bin/python python ${ORIGINAL_PY} 1 || true && \ + update-alternatives --set python ${ORIGINAL_PY} || true && \ + # Ensure pip is available for the restored Python \ + curl -sS https://bootstrap.pypa.io/get-pip.py | ${ORIGINAL_PY} # keep the build context clean WORKDIR /build COPY . /build - # Needed to resolve setuptools dependencies ENV UV_INDEX_STRATEGY="unsafe-best-match" -RUN uv pip install --system --no-cache ".[gpu]" --extra-index-url https://download.pytorch.org/whl/cu128 -RUN uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl RUN uv pip install --system --no-cache ".[bench]" RUN playwright install-deps diff --git a/README.md b/README.md index 0c40d3d..335ca6f 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,43 @@ python -m olmocr.pipeline ./localworkspace --markdown --pdfs tests/gnarly_pdfs/* With the addition of the `--markdown` flag, results will be stored as markdown files inside of `./localworkspace/markdown/`. 
+### Using External vLLM Server + +If you have a vLLM server already running elsewhere (or any inference platform implementing the relevant subset of the OpenAI API), you can point olmOCR to use it instead of spawning a local instance: + +```bash +# Use external vLLM server instead of local one +python -m olmocr.pipeline ./localworkspace --server http://remote-server:8000 --markdown --pdfs tests/gnarly_pdfs/*.pdf +``` + +The served model name should be `olmocr`. An example vLLM launch command would be: +```bash +vllm serve allenai/olmOCR-7B-0825-FP8 --served-model-name olmocr --max-model-len 16384 +``` + +#### Run olmOCR with the DeepInfra server endpoint: +Sign up at [DeepInfra](https://deepinfra.com/) and get your API key from the DeepInfra dashboard. +Store the API key as an environment variable. +```bash +export DEEPINFRA_API_KEY="your-api-key-here" +``` + +```bash +python -m olmocr.pipeline ./localworkspace \ + --server https://api.deepinfra.com/v1/openai \ + --api_key $DEEPINFRA_API_KEY \ + --pages_per_group 100 \ + --model allenai/olmOCR-7B-0825 \ + --markdown \ + --pdfs path/to/your/*.pdf +``` +- `--server`: DeepInfra's OpenAI-compatible endpoint: `https://api.deepinfra.com/v1/openai` +- `--api_key`: Your DeepInfra API key +- `--pages_per_group`: You may want a smaller number of pages per group, as many external providers have lower concurrent request limits +- `--model`: The model identifier on DeepInfra: `allenai/olmOCR-7B-0825` +- Other arguments work the same as with local inference + + #### Viewing Results The `./localworkspace/` workspace folder will then have both [Dolma](https://github.com/allenai/dolma) and markdown files (if using `--markdown`). @@ -249,6 +286,7 @@ For example: python -m olmocr.pipeline s3://my_s3_bucket/pdfworkspaces/exampleworkspace --pdfs s3://my_s3_bucket/jakep/gnarly_pdfs/*.pdf --beaker --beaker_gpus 4 ``` + ### Using Docker Pull the Docker image. @@ -284,7 +322,7 @@ python -m olmocr.pipeline ./localworkspace --markdown --pdfs olmocr-sample.pdf python -m olmocr.pipeline --help usage: pipeline.py [-h] [--pdfs [PDFS ...]] [--model MODEL] [--workspace_profile WORKSPACE_PROFILE] [--pdf_profile PDF_PROFILE] [--pages_per_group PAGES_PER_GROUP] [--max_page_retries MAX_PAGE_RETRIES] [--max_page_error_rate MAX_PAGE_ERROR_RATE] [--workers WORKERS] [--apply_filter] [--stats] [--markdown] [--target_longest_image_dim TARGET_LONGEST_IMAGE_DIM] [--target_anchor_text_len TARGET_ANCHOR_TEXT_LEN] [--guided_decoding] [--gpu-memory-utilization GPU_MEMORY_UTILIZATION] [--max_model_len MAX_MODEL_LEN] - [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--data-parallel-size DATA_PARALLEL_SIZE] [--port PORT] [--beaker] [--beaker_workspace BEAKER_WORKSPACE] [--beaker_cluster BEAKER_CLUSTER] [--beaker_gpus BEAKER_GPUS] [--beaker_priority BEAKER_PRIORITY] + [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--data-parallel-size DATA_PARALLEL_SIZE] [--port PORT] [--server SERVER] [--beaker] [--beaker_workspace BEAKER_WORKSPACE] [--beaker_cluster BEAKER_CLUSTER] [--beaker_gpus BEAKER_GPUS] [--beaker_priority BEAKER_PRIORITY] workspace Manager for running millions of PDFs through a batch inference pipeline @@ -316,7 +354,7 @@ options: Maximum amount of anchor text to use (characters), not used for new models --guided_decoding Enable guided decoding for model YAML type outputs -VLLM Forwarded arguments: +VLLM arguments: --gpu-memory-utilization GPU_MEMORY_UTILIZATION Fraction of VRAM vLLM may pre-allocate for KV-cache (passed through to vllm serve). 
--max_model_len MAX_MODEL_LEN @@ -326,6 +364,9 @@ VLLM Forwarded arguments: --data-parallel-size DATA_PARALLEL_SIZE, -dp DATA_PARALLEL_SIZE Data parallel size for vLLM --port PORT Port to use for the VLLM server + --server SERVER URL of external vLLM (or other compatible provider) + server (e.g., http://hostname:port). If provided, + skips spawning local vLLM instance beaker/cluster execution: --beaker Submit this job to beaker instead of running locally diff --git a/olmocr/bench/tests.py b/olmocr/bench/tests.py index ef58550..3b96d93 100644 --- a/olmocr/bench/tests.py +++ b/olmocr/bench/tests.py @@ -130,7 +130,7 @@ def normalize_text(md_content: str) -> str: md_content = re.sub(r"\*(.*?)\*", r"\1", md_content) md_content = re.sub(r"_(.*?)_", r"\1", md_content) - # Convert down to a consistent unicode form, so é == e + accent, unicode forms + # Convert down to a consistent unicode form, so é == e + accent, unicode forms md_content = unicodedata.normalize("NFC", md_content) # Dictionary of characters to replace: keys are fancy characters, values are ASCII equivalents, unicode micro with greek mu comes up often enough too diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index 75e0899..82316b3 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -11,6 +11,7 @@ import os import random import re import shutil +import ssl import sys import tempfile import time @@ -132,7 +133,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_ image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") return { - "model": "olmocr", + "model": model_name, "messages": [ { "role": "user", @@ -151,29 +152,44 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_ # It feels strange perhaps, but httpx and aiohttp are very complex beasts # Ex. 
the sessionpool in httpcore has 4 different locks in it, and I've noticed # that at the scale of 100M+ requests, that they deadlock in different strange ways -async def apost(url, json_data): +async def apost(url, json_data, api_key=None): parsed_url = urlparse(url) host = parsed_url.hostname - port = parsed_url.port or 80 + # Default to 443 for HTTPS, 80 for HTTP + if parsed_url.scheme == "https": + port = parsed_url.port or 443 + use_ssl = True + else: + port = parsed_url.port or 80 + use_ssl = False path = parsed_url.path or "/" writer = None try: - reader, writer = await asyncio.open_connection(host, port) + if use_ssl: + ssl_context = ssl.create_default_context() + reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context) + else: + reader, writer = await asyncio.open_connection(host, port) json_payload = json.dumps(json_data) - request = ( - f"POST {path} HTTP/1.1\r\n" - f"Host: {host}\r\n" - f"Content-Type: application/json\r\n" - f"Content-Length: {len(json_payload)}\r\n" - f"Connection: close\r\n\r\n" - f"{json_payload}" - ) + + headers = [ + f"POST {path} HTTP/1.1", + f"Host: {host}", + f"Content-Type: application/json", + f"Content-Length: {len(json_payload)}", + ] + + if api_key: + headers.append(f"Authorization: Bearer {api_key}") + + headers.append("Connection: close") + + request = "\r\n".join(headers) + "\r\n\r\n" + json_payload writer.write(request.encode()) await writer.drain() - # Read status line status_line = await reader.readline() if not status_line: raise ConnectionError("No response from server") @@ -213,7 +229,16 @@ async def apost(url, json_data): async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult: - COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions" + if args.server: + server_url = args.server.rstrip("/") + # Check if the server URL already contains '/v1/openai' (DeepInfra case) + if "/v1/openai" in server_url: + COMPLETION_URL = f"{server_url}/chat/completions" + else: + COMPLETION_URL = f"{server_url}/v1/chat/completions" + logger.debug(f"Using completion URL: {COMPLETION_URL}") + else: + COMPLETION_URL = f"http://localhost:{BASE_SERVER_PORT}/v1/chat/completions" MAX_RETRIES = args.max_page_retries MODEL_MAX_CONTEXT = 16384 TEMPERATURE_BY_ATTEMPT = [0.1, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 1.0] @@ -224,11 +249,19 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: while attempt < MAX_RETRIES: lookup_attempt = min(attempt, len(TEMPERATURE_BY_ATTEMPT) - 1) + # For external servers (like DeepInfra), use the model name from args + # For local inference, always use 'olmocr' + if args.server and hasattr(args, "model"): + model_name = args.model + else: + model_name = "olmocr" + query = await build_page_query( pdf_local_path, page_num, args.target_longest_image_dim, image_rotation=cumulative_rotation, + model_name=model_name, ) # Change temperature as number of attempts increases to overcome repetition issues at expense of quality query["temperature"] = TEMPERATURE_BY_ATTEMPT[lookup_attempt] @@ -242,10 +275,17 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: logger.debug(f"Built page query for {pdf_orig_path}-{page_num}") try: - status_code, response_body = await apost(COMPLETION_URL, json_data=query) + # Passing API key only for external servers that need authentication + if args.server and hasattr(args, "api_key"): + api_key = args.api_key + else: + api_key = None + status_code, response_body = await 
apost(COMPLETION_URL, json_data=query, api_key=api_key) if status_code == 400: raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response") + elif status_code == 429: + raise ConnectionError(f"Too many requests, doing exponential backoff") elif status_code == 500: raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response") elif status_code != 200: @@ -596,6 +636,8 @@ async def vllm_server_task(model_name_or_path, args, semaphore, unknown_args=Non str(args.tensor_parallel_size), "--data-parallel-size", str(args.data_parallel_size), + "--limit-mm-per-prompt", + '{"video": 0}', # Disabling video encoder saves RAM that you can put towards the KV cache, thanks @charitarthchugh ] if args.gpu_memory_utilization is not None: @@ -730,15 +772,27 @@ async def vllm_server_host(model_name_or_path, args, semaphore, unknown_args=Non sys.exit(1) -async def vllm_server_ready(): +async def vllm_server_ready(args): max_attempts = 300 delay_sec = 1 - url = f"http://localhost:{BASE_SERVER_PORT}/v1/models" + if args.server: + # Check if the server URL already contains '/v1/openai' (DeepInfra case) + server_url = args.server.rstrip("/") + if "/v1/openai" in server_url: + url = f"{server_url}/models" + else: + url = f"{server_url}/v1/models" + else: + url = f"http://localhost:{BASE_SERVER_PORT}/v1/models" for attempt in range(1, max_attempts + 1): try: + headers = {} + if args.server and hasattr(args, "api_key") and args.api_key: + headers["Authorization"] = f"Bearer {args.api_key}" + async with httpx.AsyncClient() as session: - response = await session.get(url) + response = await session.get(url, headers=headers) if response.status_code == 200: logger.info("vllm server is ready.") @@ -1058,6 +1112,7 @@ async def main(): parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288) parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1) parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs") + parser.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)") vllm_group = parser.add_argument_group( "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM." @@ -1069,6 +1124,11 @@ async def main(): vllm_group.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="Tensor parallel size for vLLM") vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM") vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server") + vllm_group.add_argument( + "--server", + type=str, + help="URL of external vLLM (or other compatible provider) server (e.g., http://hostname:port). 
If provided, skips spawning local vLLM instance", + ) # Beaker/job running stuff beaker_group = parser.add_argument_group("beaker/cluster execution") @@ -1207,12 +1267,17 @@ async def main(): # If you get this far, then you are doing inference and need a GPU # check_sglang_version() - check_torch_gpu_available() + if not args.server: + check_torch_gpu_available() logger.info(f"Starting pipeline with PID {os.getpid()}") # Download the model before you do anything else - model_name_or_path = await download_model(args.model) + if args.server: + logger.info(f"Using external server at {args.server}") + model_name_or_path = None + else: + model_name_or_path = await download_model(args.model) # Initialize the work queue qsize = await work_queue.initialize_queue() @@ -1226,9 +1291,12 @@ async def main(): # As soon as one worker is no longer saturating the gpu, the next one can start sending requests semaphore = asyncio.Semaphore(1) - vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args)) + # Start local vLLM instance if not using external one + vllm_server = None + if not args.server: + vllm_server = asyncio.create_task(vllm_server_host(model_name_or_path, args, semaphore, unknown_args)) - await vllm_server_ready() + await vllm_server_ready(args) metrics_task = asyncio.create_task(metrics_reporter(work_queue)) @@ -1241,11 +1309,16 @@ async def main(): # Wait for all worker tasks to finish await asyncio.gather(*worker_tasks) - vllm_server.cancel() + # Cancel vLLM server if it was started + if vllm_server is not None: + vllm_server.cancel() metrics_task.cancel() # Wait for cancelled tasks to complete - await asyncio.gather(vllm_server, metrics_task, return_exceptions=True) + tasks_to_wait = [metrics_task] + if vllm_server is not None: + tasks_to_wait.append(vllm_server) + await asyncio.gather(*tasks_to_wait, return_exceptions=True) # Output final metrics summary metrics_summary = metrics.get_metrics_summary() diff --git a/olmocr/version.py b/olmocr/version.py index bf1c1af..3bbaff7 100644 --- a/olmocr/version.py +++ b/olmocr/version.py @@ -2,7 +2,7 @@ _MAJOR = "0" _MINOR = "3" # On main and in a nightly release the patch should be one ahead of the last # released build. -_PATCH = "3" +_PATCH = "8" # This is mainly for nightly builds which have the suffix ".dev$DATE". See # https://semver.org/#is-v123-a-semantic-version for the semantics. _SUFFIX = "" diff --git a/pyproject.toml b/pyproject.toml index 163c5eb..52142a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "boto3", "httpx", "torch>=2.7.0", - "transformers==4.53.2", + "transformers==4.55.2", "img2pdf", "beaker-py", ] @@ -51,7 +51,7 @@ Changelog = "https://github.com/allenai/olmocr/blob/main/CHANGELOG.md" [project.optional-dependencies] gpu = [ - "vllm==0.10.0" + "vllm==0.11.0" ] dev = [ diff --git a/scripts/release.sh b/scripts/release.sh index dc5ab60..ef30083 100755 --- a/scripts/release.sh +++ b/scripts/release.sh @@ -68,7 +68,7 @@ read -p "Creating new release for $TAG. Do you want to continue? 
[Y/n] " prompt if [[ $prompt == "y" || $prompt == "Y" || $prompt == "yes" || $prompt == "Yes" ]]; then python scripts/prepare_changelog.py - git add -A + git add CHANGELOG.md git commit -m "Bump version to $TAG for release" || true && git push echo "Creating new git tag $TAG" git tag "$TAG" -m "$TAG" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index e0d69d9..1541639 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -192,6 +192,7 @@ class MockArgs: max_page_retries: int = 8 target_longest_image_dim: int = 1288 guided_decoding: bool = False + server: str | None = None class TestRotationCorrection: @@ -208,7 +209,7 @@ class TestRotationCorrection: # Counter to track number of API calls call_count = 0 - async def mock_apost(url, json_data): + async def mock_apost(url, json_data, api_key=None): nonlocal call_count call_count += 1 @@ -267,9 +268,9 @@ This is the corrected text from the document.""" build_page_query_calls = [] original_build_page_query = build_page_query - async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): + async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"): build_page_query_calls.append(image_rotation) - return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) + return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name) with patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.tracker", mock_tracker): @@ -310,7 +311,7 @@ This is the corrected text from the document.""" # Counter to track number of API calls call_count = 0 - async def mock_apost(url, json_data): + async def mock_apost(url, json_data, api_key=None): nonlocal call_count call_count += 1 @@ -375,9 +376,9 @@ Document is now correctly oriented after 180 degree rotation.""" build_page_query_calls = [] original_build_page_query = build_page_query - async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): + async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"): build_page_query_calls.append(image_rotation) - return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) + return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name) with patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.tracker", mock_tracker): @@ -419,7 +420,7 @@ Document is now correctly oriented after 180 degree rotation.""" # Counter to track number of API calls call_count = 0 - async def mock_apost(url, json_data): + async def mock_apost(url, json_data, api_key=None): nonlocal call_count call_count += 1 @@ -481,9 +482,9 @@ Document correctly oriented at 90 degrees total rotation.""" build_page_query_calls = [] original_build_page_query = build_page_query - async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0): + async def mock_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation=0, model_name="olmocr"): build_page_query_calls.append(image_rotation) - return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation) + return await original_build_page_query(local_pdf_path, page, target_longest_image_dim, image_rotation, model_name) with 
patch("olmocr.pipeline.apost", side_effect=mock_apost): with patch("olmocr.pipeline.tracker", mock_tracker):