This commit is contained in:
Jake Poznanski 2025-10-20 18:43:13 +00:00
parent 4fc9cd112b
commit eaf83026d3
2 changed files with 17 additions and 28 deletions

View File

@ -1,21 +1,10 @@
import json
from typing import Literal
import httpx
import base64 import base64
import subprocess import subprocess
from PIL import Image
from olmocr.bench.prompts import ( import httpx
build_basic_prompt,
build_openai_silver_data_prompt_no_document_anchoring,
)
from olmocr.data.renderpdf import get_pdf_media_box_width_height from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
)
# Logic to set min size from here: https://github.com/NanoNets/Nanonets-OCR2/blob/main/Nanonets-OCR2-Cookbook/image2md.ipynb # Logic to set min size from here: https://github.com/NanoNets/Nanonets-OCR2/blob/main/Nanonets-OCR2-Cookbook/image2md.ipynb
def render_pdf_to_base64png_min_short_size(local_pdf_path: str, page_num: int, target_shortest_dim: int = 2048) -> str: def render_pdf_to_base64png_min_short_size(local_pdf_path: str, page_num: int, target_shortest_dim: int = 2048) -> str:
@ -64,19 +53,21 @@ async def run_server(
""" """
# Convert the first page of the PDF to a base64-encoded PNG image. # Convert the first page of the PDF to a base64-encoded PNG image.
image_base64 = render_pdf_to_base64png_min_short_size(pdf_path, page_num=page_num, target_shortest_dim=page_dimensions) image_base64 = render_pdf_to_base64png_min_short_size(pdf_path, page_num=page_num, target_shortest_dim=page_dimensions)
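# Now use the rendered page image to build the prompt and chat request # Now use the rendered page image to build the prompt and chat request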
prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes.""" prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the <img></img> tag; otherwise, add the image caption inside <img></img>. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
request = { request = {
"model": model, "model": model,
"messages": [ "messages": [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [ {
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}, "role": "user",
{"type": "text", "text": prompt}, "content": [
]}, {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
{"type": "text", "text": prompt},
],
},
], ],
"temperature": temperature, "temperature": temperature,
"max_tokens": 4096, "max_tokens": 4096,

View File

@ -214,17 +214,17 @@ async def apost(url, json_data, api_key=None):
# Read chunk size line # Read chunk size line
size_line = await reader.readline() size_line = await reader.readline()
chunk_size = int(size_line.strip(), 16) # Hex format chunk_size = int(size_line.strip(), 16) # Hex format
if chunk_size == 0: if chunk_size == 0:
await reader.readline() # Read final CRLF await reader.readline() # Read final CRLF
break break
chunk_data = await reader.readexactly(chunk_size) chunk_data = await reader.readexactly(chunk_size)
chunks.append(chunk_data) chunks.append(chunk_data)
# Read trailing CRLF after chunk data # Read trailing CRLF after chunk data
await reader.readline() await reader.readline()
response_body = b"".join(chunks) response_body = b"".join(chunks)
elif headers.get("connection", "") == "close": elif headers.get("connection", "") == "close":
# Read until connection closes # Read until connection closes
@ -1115,7 +1115,6 @@ async def main():
) )
server_group.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)") server_group.add_argument("--api_key", type=str, default=None, help="API key for authenticated remote servers (e.g., DeepInfra)")
vllm_group = parser.add_argument_group( vllm_group = parser.add_argument_group(
"VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM." "VLLM arguments", "These arguments are passed to vLLM. Any unrecognized arguments are also automatically forwarded to vLLM."
) )
@ -1127,7 +1126,6 @@ async def main():
vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM") vllm_group.add_argument("--data-parallel-size", "-dp", type=int, default=1, help="Data parallel size for vLLM")
vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server") vllm_group.add_argument("--port", type=int, default=30024, help="Port to use for the VLLM server")
# Beaker/job running stuff # Beaker/job running stuff
beaker_group = parser.add_argument_group("beaker/cluster execution") beaker_group = parser.add_argument_group("beaker/cluster execution")
beaker_group.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally") beaker_group.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
@ -1277,7 +1275,7 @@ async def main():
if use_internal_server: if use_internal_server:
model_name_or_path = await download_model(args.model) model_name_or_path = await download_model(args.model)
args.server = f"http://localhost:{args.port}/v1" args.server = f"http://localhost:{args.port}/v1"
args.model = "olmocr" # Internal server always uses this name for the model, for supporting weird local model paths args.model = "olmocr" # Internal server always uses this name for the model, for supporting weird local model paths
logger.info(f"Using internal server at {args.server}") logger.info(f"Using internal server at {args.server}")
else: else:
logger.info(f"Using external server at {args.server}") logger.info(f"Using external server at {args.server}")