Lint fixes

Jake Poznanski 2025-07-23 03:40:05 +00:00
parent 5ec49672ea
commit 6e8272413c
10 changed files with 405 additions and 475 deletions

View File

@ -17,7 +17,7 @@ import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from functools import cache
from io import BytesIO
from urllib.parse import urlparse
@ -34,7 +34,6 @@ from olmocr.check import (
check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.train.dataloader import FrontMatterParser
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
@ -48,6 +47,7 @@ from olmocr.s3_utils import (
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
@ -227,7 +227,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
# Enable guided decoding regex if needed
if args.guided_decoding:
query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
query["guided_regex"] = (
r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
@ -329,6 +331,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
is_fallback=True,
)
async def process_pdf(args, worker_id: int, pdf_orig_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
try:
@ -1016,7 +1019,11 @@ async def main():
)
parser.add_argument("--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve).")
parser.add_argument("--max_model_len", type=int, help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",)
parser.add_argument(
"--max_model_len",
type=int,
help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",
)
parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
@ -1227,12 +1234,12 @@ async def main():
# Output finished_on_attempt statistics
logger.info("\nPages finished by attempt number:")
total_finished = sum(total_metrics.get(f'finished_on_attempt_{i}', 0) for i in range(args.max_page_retries))
total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
cumulative = 0
for i in range(args.max_page_retries):
if f'finished_on_attempt_{i}' in total_metrics:
count = total_metrics[f'finished_on_attempt_{i}']
if f"finished_on_attempt_{i}" in total_metrics:
count = total_metrics[f"finished_on_attempt_{i}"]
cumulative += count
percentage = (count / total_finished * 100) if total_finished > 0 else 0
cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0

View File

@ -1,8 +1,8 @@
from .prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
build_no_anchoring_yaml_prompt,
build_openai_silver_data_prompt,
extract_raw_text,
openai_response_format_schema,
)

View File

@ -109,9 +109,9 @@ def build_finetuning_prompt(base_text: str) -> str:
def build_no_anchoring_yaml_prompt() -> str:
return (
f"Attached is one page of a document that you must process. "
f"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
f"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
"Attached is one page of a document that you must process. "
"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
)

View File

@ -6,22 +6,21 @@ Processes prompts and images from WildVision-bench until finding significant mis
import argparse
import asyncio
import gc
import os
import glob
import tempfile
import shutil
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from huggingface_hub import snapshot_download
import random
import numpy as np
from typing import List, Dict
import base64
from io import BytesIO
import PIL.Image
import logging
import os
import random
import shutil
import tempfile
from io import BytesIO
from typing import Dict, List
import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForVision2Seq, AutoProcessor
from vllm import LLM, SamplingParams
from olmocr.pipeline import build_page_query
from olmocr.s3_utils import download_directory
@ -51,7 +50,7 @@ async def download_model(model_name_or_path: str, max_retries: int = 5):
if retry == max_retries - 1:
raise # Raise on final attempt and fail the job
logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
await asyncio.sleep(2 ** retry) # Exponential backoff
await asyncio.sleep(2**retry) # Exponential backoff
def image_to_base64_data_url(image):
@ -71,14 +70,11 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
np.random.seed(seed)
# Import huggingface_hub utilities to list files
from huggingface_hub import list_repo_files, hf_hub_download
from huggingface_hub import hf_hub_download, list_repo_files
# List all PDF files in the repository
print("Listing PDF files in dataset...")
all_files = list_repo_files(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
repo_type="dataset"
)
all_files = list_repo_files(repo_id="allenai/olmOCR-mix-0225-benchmarkset", repo_type="dataset")
# Filter for PDF files in the pdfs directory
pdf_files = [f for f in all_files if f.startswith("pdfs/") and f.endswith(".pdf")]
@ -104,20 +100,10 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
try:
# Download individual PDF file
print(f"Downloading {pdf_file}...")
local_pdf_path = hf_hub_download(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
filename=pdf_file,
repo_type="dataset",
local_dir=temp_dir
)
local_pdf_path = hf_hub_download(repo_id="allenai/olmOCR-mix-0225-benchmarkset", filename=pdf_file, repo_type="dataset", local_dir=temp_dir)
# Build page query for page 1 of each PDF
query = await build_page_query(
local_pdf_path=local_pdf_path,
page=1,
target_longest_image_dim=1280,
image_rotation=0
)
query = await build_page_query(local_pdf_path=local_pdf_path, page=1, target_longest_image_dim=1280, image_rotation=0)
queries.append(query)
except Exception as e:
print(f"Error processing {os.path.basename(pdf_file)}: {e}")
@ -126,27 +112,28 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
print(f"Successfully processed {len(queries)} PDFs")
return queries
def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
"""Process a single prompt with image and return comparison results."""
# Track if we found the first mismatch for max_prob_first_diff
found_first_mismatch = False
max_prob_first_diff = 0.0
# Extract messages from the sample (which is the output of build_page_query)
messages = sample['messages']
messages = sample["messages"]
# Extract the text prompt and image from the messages
user_message = messages[0]
text_prompt = None
image_base64 = None
for content in user_message['content']:
if content['type'] == 'text':
text_prompt = content['text']
elif content['type'] == 'image_url':
image_url = content['image_url']['url']
for content in user_message["content"]:
if content["type"] == "text":
text_prompt = content["text"]
elif content["type"] == "image_url":
image_url = content["image_url"]["url"]
# Extract base64 data after the comma
if ',' in image_url:
image_base64 = image_url.split(',')[1]
if "," in image_url:
image_base64 = image_url.split(",")[1]
else:
image_base64 = image_url
@ -172,7 +159,11 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
prompt_token_ids = output.prompt_token_ids
generated_token_ids = output.outputs[0].token_ids
print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
print(
f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..."
if len(prompt_token_ids) > 10
else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}"
)
print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")
@ -182,21 +173,9 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# HuggingFace forward pass
print("\n=== HuggingFace Forward Pass ===")
# Prepare inputs for HF model using the extracted image and text
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text_prompt}
]
}
]
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text_prompt}]}]
hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(
text=[hf_text_prompt],
images=[image],
return_tensors="pt"
).to(device)
inputs = processor(text=[hf_text_prompt], images=[image], return_tensors="pt").to(device)
print("INPUTS", inputs)
@ -224,7 +203,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Compare ALL tokens (prompt + generated)
for pos, token_id in enumerate(all_token_ids):
token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
token_str = processor.decode([token_id], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
# Determine if this is a prompt or generated token
is_prompt = pos < len(prompt_token_ids)
@ -265,7 +244,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Decode HF argmax token (only show if mismatch)
hf_token_str = ""
if token_id != hf_argmax:
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
else:
@ -296,27 +275,21 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
print("No mismatches found in generated tokens")
return {
'first_mismatch_idx': first_mismatch_idx,
'max_prob_first_diff': max_prob_first_diff,
'match_rate': match_rate,
'num_generated': len(generated_token_ids)
"first_mismatch_idx": first_mismatch_idx,
"max_prob_first_diff": max_prob_first_diff,
"match_rate": match_rate,
"num_generated": len(generated_token_ids),
}
async def async_main():
parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20,
help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0,
help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100,
help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20,
help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42,
help="Random seed for prompt selection")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20, help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100, help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20, help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42, help="Random seed for prompt selection")
args = parser.parse_args()
print(f"Model: {args.model}")
@ -335,22 +308,13 @@ async def async_main():
print("\n=== Loading HuggingFace Model ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
hf_model = AutoModelForVision2Seq.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map="auto"
)
hf_model = AutoModelForVision2Seq.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
hf_model.eval()
# Create vLLM engine once
print("\n=== Creating vLLM Engine ===")
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
logprobs=1 # Get top-1 logprobs
)
sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens, logprobs=1) # Get top-1 logprobs
# Process samples until finding significant mismatch
print("\n=== Processing Samples ===")
@ -367,7 +331,7 @@ async def async_main():
all_results.append(result)
# Check if we found significant mismatch
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
if result["first_mismatch_idx"] is not None and result["max_prob_first_diff"] > args.prob_threshold:
print(f"\n\n{'*'*80}")
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
@ -380,20 +344,20 @@ async def async_main():
print(f"{'='*80}")
total_samples = len(all_results)
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
total_tokens_generated = sum(r['num_generated'] for r in all_results)
samples_with_mismatches = sum(1 for r in all_results if r["first_mismatch_idx"] is not None)
total_tokens_generated = sum(r["num_generated"] for r in all_results)
print(f"Total samples processed: {total_samples}")
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
print(f"Total tokens generated: {total_tokens_generated}")
if samples_with_mismatches > 0:
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
avg_match_rate = sum(r["match_rate"] for r in all_results) / total_samples
max_prob_diffs = [r["max_prob_first_diff"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
max_prob_diff_overall = max(max_prob_diffs)
first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
first_mismatch_positions = [r["first_mismatch_idx"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
print(f"\nMismatch Statistics:")

View File

@ -26,18 +26,22 @@ import shutil
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple, Union, List
from typing import List, Optional, Tuple, Union
import boto3
import torch
from datasets import Dataset
from llmcompressor import oneshot
from PIL import Image
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from olmocr.s3_utils import parse_s3_path
from olmocr.pipeline import build_page_query
from olmocr.s3_utils import parse_s3_path
s3_client = boto3.client("s3")
@ -62,7 +66,7 @@ def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
# Verify all PDFs exist
valid_paths = []
for path in pdf_paths:
if os.path.exists(path) and path.endswith('.pdf'):
if os.path.exists(path) and path.endswith(".pdf"):
valid_paths.append(path)
else:
print(f" Warning: Skipping invalid path: {path}")
@ -101,9 +105,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
images.append(image)
# Apply chat template to get the text
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Process with tokenizer
inputs = processor(
@ -123,7 +125,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
all_datasets = []
for i in range(0, len(dataset_items), batch_size):
batch = dataset_items[i:i + batch_size]
batch = dataset_items[i : i + batch_size]
# Flatten the batch into a dict of lists
batch_dict = {}
for key in batch[0].keys():
@ -138,6 +140,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
return all_datasets[0]
else:
from datasets import concatenate_datasets
return concatenate_datasets(all_datasets)
else:
return Dataset.from_dict({})
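The reformatted slice above (`dataset_items[i : i + batch_size]`) sits inside a batch-then-concatenate pattern for building the calibration dataset. A minimal sketch of that pattern with invented data (the column name is illustrative, not the repo's):

```python
# Minimal sketch (invented data): flatten list-of-dicts batches into dict-of-lists,
# build a Dataset per batch, then concatenate the pieces.
from datasets import Dataset, concatenate_datasets

dataset_items = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}, {"input_ids": [6]}]
batch_size = 2

all_datasets = []
for i in range(0, len(dataset_items), batch_size):
    batch = dataset_items[i : i + batch_size]
    # Dataset.from_dict expects a dict of columns, so flatten the batch first.
    batch_dict = {key: [item[key] for item in batch] for key in batch[0].keys()}
    all_datasets.append(Dataset.from_dict(batch_dict))

combined = all_datasets[0] if len(all_datasets) == 1 else concatenate_datasets(all_datasets)
assert len(combined) == 3
```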
@ -185,7 +188,9 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
print(f" Uploaded {rel_path}")
def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
def load_model_and_tokenizer(
source_path: str,
) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
"""Load model and tokenizer from source path (local or S3)."""
if is_s3_path(source_path):
# Download from S3 to temporary directory
@ -210,35 +215,19 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
# Load appropriate model class based on name
if "Qwen2.5-VL" in model_name:
print("Detected Qwen2.5-VL model")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
elif "Qwen2-VL" in model_name:
print("Detected Qwen2-VL model")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
# Default to checking architectures list
architectures = config.get("architectures", [])
if "Qwen2_5_VLForConditionalGeneration" in architectures:
print("Detected Qwen2.5-VL model from architectures")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
print("Detected Qwen2-VL model from architectures")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
print(f"Loading tokenizer from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
@ -268,7 +257,9 @@ def data_collator(batch):
return {key: torch.tensor(value) for key, value in batch[0].items()}
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None) -> None:
def compress_checkpoint(
source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None
) -> None:
"""Compress OlmOCR checkpoint using FP8 quantization."""
# Load model and tokenizer
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@ -304,14 +295,7 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
print(f"\nApplying quantization using recipe: {recipe_path}")
if dataset:
oneshot(
model=model,
recipe=recipe_path,
dataset=dataset,
max_seq_length=8192,
num_calibration_samples=len(dataset),
data_collator=data_collator
)
oneshot(model=model, recipe=recipe_path, dataset=dataset, max_seq_length=8192, num_calibration_samples=len(dataset), data_collator=data_collator)
else:
oneshot(model=model, recipe=recipe_path)
@ -377,15 +361,18 @@ Examples:
# Using recursive glob pattern
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
"""
""",
)
parser.add_argument("source", help="Source checkpoint path (local or S3)")
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
parser.add_argument("--num-calibration-samples", type=int, default=512,
help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument("--calibration-pdfs", type=str, default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
parser.add_argument("--num-calibration-samples", type=int, default=512, help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument(
"--calibration-pdfs",
type=str,
default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.",
)
args = parser.parse_args()
@ -396,29 +383,27 @@ Examples:
pattern = args.calibration_pdfs
# Handle both absolute and relative paths with recursive glob
if '**' in pattern:
if "**" in pattern:
# For recursive patterns, we need to handle them specially
if pattern.startswith('/'):
if pattern.startswith("/"):
# Absolute path with **
parts = pattern.split('**')
parts = pattern.split("**")
base_path = Path(parts[0])
glob_pattern = '**' + parts[1] if len(parts) > 1 else '**/*.pdf'
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip('/')))
glob_pattern = "**" + parts[1] if len(parts) > 1 else "**/*.pdf"
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip("/")))
else:
# Relative path with **
calibration_pdfs = list(Path('.').glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == '.pdf']
calibration_pdfs = list(Path(".").glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == ".pdf"]
else:
# Use standard glob for non-recursive patterns
calibration_pdfs = glob.glob(pattern)
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith('.pdf')]
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith(".pdf")]
print(f"Found {len(calibration_pdfs)} PDF files matching pattern: {args.calibration_pdfs}")
compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)
if __name__ == "__main__":
exit(main())

View File

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from omegaconf import DictConfig, OmegaConf
from omegaconf import OmegaConf
@dataclass
@ -314,10 +314,10 @@ class Config:
FrontMatterOutputFormat,
FrontMatterParser,
InstructUserMessages,
JSONOutputFormat,
LatexBracketNormalizer,
NewYamlFinetuningPromptWithAnchoring,
NewYamlFinetuningPromptWithNoAnchoring,
JSONOutputFormat,
PDFRenderer,
RandomTokenFlipper,
StaticLengthDocumentAnchoring,

View File

@ -1,4 +1,5 @@
import base64
import json
import logging
import re
from abc import ABC, abstractmethod
@ -8,8 +9,7 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
import json
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias
import numpy as np
import yaml
@ -60,6 +60,7 @@ def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optiona
# Test that document anchoring works
from olmocr.prompts.anchor import get_anchor_text
get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
return {"markdown_path": md_path, "pdf_path": pdf_path}, None
@ -205,11 +206,11 @@ class FrontMatterParser(PipelineStep):
value = front_matter_dict[field_name]
# Handle type conversions
if field_type == int and isinstance(value, str):
if field_type is int and isinstance(value, str):
kwargs[field_name] = int(value)
elif field_type == bool and isinstance(value, str):
elif field_type is bool and isinstance(value, str):
kwargs[field_name] = value.lower() == "true"
elif field_type == Optional[str]:
elif field_type is Optional[str]:
kwargs[field_name] = value if value else None
else:
kwargs[field_name] = value
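The `==` to `is` changes above are ruff's E721 fix: for exact type checks, identity on the type object is the idiomatic comparison. A minimal sketch of the same coercion pattern, with invented inputs:

```python
# Minimal sketch (not from the repo): exact-type checks use `is` (ruff E721),
# since type objects are singletons and `==` reads as a value comparison.
def coerce(field_type: type, value):
    if field_type is int and isinstance(value, str):
        return int(value)
    if field_type is bool and isinstance(value, str):
        return value.lower() == "true"
    return value


assert coerce(int, "42") == 42
assert coerce(bool, "False") is False
assert coerce(str, "unchanged") == "unchanged"
```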
@ -323,7 +324,7 @@ class FrontMatterOutputFormat(PipelineStep):
def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse
sample["response"] = (
f"""---
@ -346,19 +347,23 @@ class JSONOutputFormat(PipelineStep):
def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse
sample["response"] = json.dumps({
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text
}, ensure_ascii=False)
sample["response"] = json.dumps(
{
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text,
},
ensure_ascii=False,
)
return sample
@dataclass(frozen=True, slots=True)
class LatexBracketNormalizer(PipelineStep):
"""Normalizes LaTeX brackets in natural text field."""
@ -379,7 +384,7 @@ class LatexBracketNormalizer(PipelineStep):
# Order matters: process display math first, then inline
patterns = [
(r"\$\$(.+?)\$\$", r"\[\1\]"), # $$...$$ to \[...\]
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
]
# Apply replacements
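Only comment alignment changed in the patterns list, but the ordering it documents matters: display math must be rewritten before inline math, or the inline pattern would consume part of the `$$` delimiters. A minimal sketch with a made-up string:

```python
# Minimal sketch (invented input): apply the display-math pattern before the
# inline one, mirroring the order of the patterns list above.
import re

patterns = [
    (r"\$\$(.+?)\$\$", r"\[\1\]"),  # $$...$$ -> \[...\]
    (r"\$(.+?)\$", r"\(\1\)"),      # $...$   -> \(...\)
]

text = "Area: $$A = b h$$ and inline $a + b$."
for pattern, replacement in patterns:
    text = re.sub(pattern, replacement, text)

assert text == "Area: \\[A = b h\\] and inline \\(a + b\\)."
# Reversing the order would let r"\$(.+?)\$" match across the "$$" delimiters
# and produce mismatched brackets.
```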
@ -389,13 +394,14 @@ class LatexBracketNormalizer(PipelineStep):
# Update the page_data with normalized text
# Since PageResponse is frozen, we need to create a new instance
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=page_data.is_rotation_valid,
rotation_correction=page_data.rotation_correction,
is_table=page_data.is_table,
is_diagram=page_data.is_diagram,
natural_text=text
natural_text=text,
)
sample["page_data"] = new_page_data
@ -608,6 +614,7 @@ if __name__ == "__main__":
# Load processor for tokenization
print(f"\nLoading processor: {config.model.name}")
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(config.model.name)
# Select dataset based on type
@ -655,23 +662,23 @@ if __name__ == "__main__":
print("\nSample keys:", list(sample.keys()))
# If it's raw data (no tokenization)
if 'markdown_path' in sample:
if "markdown_path" in sample:
print(f"\nMarkdown file: {sample['markdown_path'].name}")
if 'pdf_path' in sample:
if "pdf_path" in sample:
print(f"PDF file: {sample['pdf_path'].name}")
if 'image' in sample and hasattr(sample['image'], 'size'):
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")
if 'page_data' in sample:
if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if 'messages' in sample:
if "messages" in sample:
print(f"\n=== Messages ===")
for i, msg in enumerate(sample['messages']):
for i, msg in enumerate(sample["messages"]):
print(f"\nMessage {i}:")
print(f" Role: {msg['role']}")
print(f" Content preview: {str(msg['content'])[:200]}...")
# If it's tokenized data
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n=== Tokenized Output ===")
print(f" Keys: {list(sample.keys())}")
print(f" Input IDs shape: {sample['input_ids'].shape}")
@ -749,7 +756,7 @@ if __name__ == "__main__":
print(f" Total sequence length: {total_count}")
# Analyze token length distribution across entire dataset
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n\n=== Analyzing token length distribution across entire dataset ===")
print(f"Processing {len(dataset)} samples...")
@ -757,9 +764,9 @@ if __name__ == "__main__":
def process_sample(idx):
try:
current_sample = dataset[idx]
if 'labels' in current_sample:
if "labels" in current_sample:
# Count total sequence length (all tokens, prompt + completion)
labels = current_sample['labels']
labels = current_sample["labels"]
total_length = len(labels)
return (idx, total_length, None)
return (idx, None, "No labels in sample")
@ -774,6 +781,7 @@ if __name__ == "__main__":
# Determine number of workers (use fewer workers to avoid memory issues)
import multiprocessing
num_workers = min(multiprocessing.cpu_count() // 2, 8)
with ProcessPoolExecutor(max_workers=num_workers) as executor:

View File

@ -38,9 +38,9 @@ import tempfile
import boto3
import requests
import torch
from smart_open import smart_open
from tqdm import tqdm
import torch
try:
from safetensors.torch import load_file, save_file
@ -50,32 +50,16 @@ except ImportError:
from olmocr.s3_utils import parse_s3_path
# Hugging Face model IDs for tokenizer files
HF_MODEL_IDS = {
"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct",
"Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"
}
HF_MODEL_IDS = {"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct", "Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"}
# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
"chat_template.json",
"merges.txt",
"preprocessor_config.json",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json"
]
TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"]
# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
# Files to exclude from copying (training-related files)
EXCLUDED_FILES = {
"optimizer.pt",
"scheduler.pt",
"rng_state.pth",
"trainer_state.json",
"training_args.bin"
}
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}
s3_client = boto3.client("s3")
@ -125,10 +109,7 @@ def detect_checkpoint_architecture(config_path: str) -> str:
elif "Qwen2-VL" in model_name:
detected_architecture = "Qwen2VLForConditionalGeneration"
else:
raise ValueError(
f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} "
f"but found: {architectures}"
)
raise ValueError(f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} " f"but found: {architectures}")
print(f"✓ Detected architecture: {detected_architecture}")
return detected_architecture
@ -194,10 +175,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(download_file_from_s3, bucket, key, local_path)
for bucket, key, local_path in download_tasks
]
futures = [executor.submit(download_file_from_s3, bucket, key, local_path) for bucket, key, local_path in download_tasks]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
future.result()
@ -220,10 +198,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
future.result()
@ -350,10 +325,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)
if file_path.endswith('.safetensors'):
sum_state = load_file(file_path, device='cpu')
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
for other_path in all_paths[1:]:
other_state = load_file(other_path, device='cpu')
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -363,10 +338,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
for k in sum_state:
sum_state[k] /= n
save_file(sum_state, souped_path)
elif file_path.endswith('.bin'):
sum_state = torch.load(file_path, map_location='cpu')
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location='cpu')
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -396,10 +371,7 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
# Download files
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()
@ -413,19 +385,13 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
print("Uploading tokenizer files to S3...")
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in concurrent.futures.as_completed(futures):
future.result()
else:
# Download directly to destination
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, dest_path, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, dest_path, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()
@ -436,9 +402,9 @@ def main():
parser = argparse.ArgumentParser(
description="Prepare OlmOCR checkpoint for deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__.split("Usage:")[1] # Use the docstring for epilog
epilog=__doc__.split("Usage:")[1], # Use the docstring for epilog
)
parser.add_argument("paths", nargs='+', help="One or more source paths followed by destination path (local or S3)")
parser.add_argument("paths", nargs="+", help="One or more source paths followed by destination path (local or S3)")
args = parser.parse_args()

View File

@ -5,6 +5,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
from typing import Optional
import numpy as np
import torch
@ -18,7 +19,6 @@ from transformers import (
TrainingArguments,
)
from typing import Optional
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
@ -50,9 +50,9 @@ class QwenDataCollator:
# Trim to max_token_len if specified
if self.max_token_len is not None:
input_ids = input_ids[:self.max_token_len]
attention_mask = attention_mask[:self.max_token_len]
labels = labels[:self.max_token_len]
input_ids = input_ids[: self.max_token_len]
attention_mask = attention_mask[: self.max_token_len]
labels = labels[: self.max_token_len]
batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
@ -209,7 +209,7 @@ def main():
adam_epsilon=config.training.adam_epsilon,
weight_decay=config.training.weight_decay,
max_grad_norm=config.training.max_grad_norm,
bf16=True, # We're sticking with this known good reduced precision option
bf16=True, # We're sticking with this known good reduced precision option
eval_strategy=config.training.evaluation_strategy,
eval_steps=config.training.eval_steps,
save_strategy=config.training.save_strategy,

View File

@ -167,7 +167,7 @@ reportPrivateImportUsage = false
line-length = 160
target-version = "py311"
exclude = ["olmocr/train/molmo", "tests/*"]
ignore = ["E722"] #igore bare except
ignore = ["E722", "F541"] #igore bare except, and f string without placeholders
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]