Lint fixes

Jake Poznanski 2025-07-23 03:40:05 +00:00
parent 5ec49672ea
commit 6e8272413c
10 changed files with 405 additions and 475 deletions

View File

@ -17,7 +17,7 @@ import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from functools import cache
from io import BytesIO
from urllib.parse import urlparse
@ -34,7 +34,6 @@ from olmocr.check import (
check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.train.dataloader import FrontMatterParser
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
@ -48,6 +47,7 @@ from olmocr.s3_utils import (
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
@ -227,7 +227,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
# Enable guided decoding regex if needed
if args.guided_decoding:
query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
query["guided_regex"] = (
r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
@ -247,7 +249,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
raise ValueError("Response exceeded model_max_context, cannot use this response")
if base_response_data["choices"][0]["finish_reason"] != "stop":
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
@ -329,6 +331,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
is_fallback=True,
)
async def process_pdf(args, worker_id: int, pdf_orig_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
try:
@ -586,9 +589,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):
if args.gpu_memory_utilization is not None:
cmd.extend(["--gpu-memory-utilization", str(args.gpu_memory_utilization)])
if args.max_model_len is not None:
cmd.extend(["--max-model-len", str(args.max_model_len)])
cmd.extend(["--max-model-len", str(args.max_model_len)])
proc = await asyncio.create_subprocess_exec(
*cmd,
@ -1016,7 +1019,11 @@ async def main():
)
parser.add_argument("--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve).")
parser.add_argument("--max_model_len", type=int, help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",)
parser.add_argument(
"--max_model_len",
type=int,
help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",
)
parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
@ -1041,7 +1048,7 @@ async def main():
logger.info(
"If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
)
global workspace_s3, pdf_s3
# set the global BASE_SERVER_PORT from args
global BASE_SERVER_PORT
@ -1227,12 +1234,12 @@ async def main():
# Output finished_on_attempt statistics
logger.info("\nPages finished by attempt number:")
total_finished = sum(total_metrics.get(f'finished_on_attempt_{i}', 0) for i in range(args.max_page_retries))
total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
cumulative = 0
for i in range(args.max_page_retries):
if f'finished_on_attempt_{i}' in total_metrics:
count = total_metrics[f'finished_on_attempt_{i}']
if f"finished_on_attempt_{i}" in total_metrics:
count = total_metrics[f"finished_on_attempt_{i}"]
cumulative += count
percentage = (count / total_finished * 100) if total_finished > 0 else 0
cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
@ -1253,4 +1260,4 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@ -1,8 +1,8 @@
from .prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
build_no_anchoring_yaml_prompt,
build_openai_silver_data_prompt,
extract_raw_text,
openai_response_format_schema,
)

View File

@ -109,9 +109,9 @@ def build_finetuning_prompt(base_text: str) -> str:
def build_no_anchoring_yaml_prompt() -> str:
return (
f"Attached is one page of a document that you must process. "
f"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
f"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
"Attached is one page of a document that you must process. "
"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
)

View File

@ -6,22 +6,21 @@ Processes prompts and images from WildVision-bench until finding significant mis
import argparse
import asyncio
import gc
import os
import glob
import tempfile
import shutil
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from huggingface_hub import snapshot_download
import random
import numpy as np
from typing import List, Dict
import base64
from io import BytesIO
import PIL.Image
import logging
import os
import random
import shutil
import tempfile
from io import BytesIO
from typing import Dict, List
import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForVision2Seq, AutoProcessor
from vllm import LLM, SamplingParams
from olmocr.pipeline import build_page_query
from olmocr.s3_utils import download_directory
@ -51,7 +50,7 @@ async def download_model(model_name_or_path: str, max_retries: int = 5):
if retry == max_retries - 1:
raise # Raise on final attempt and fail the job
logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
await asyncio.sleep(2 ** retry) # Exponential backoff
await asyncio.sleep(2**retry) # Exponential backoff
def image_to_base64_data_url(image):
@ -65,38 +64,35 @@ def image_to_base64_data_url(image):
async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
"""Load prompts and images from olmOCR-mix-0225-benchmarkset dataset with fixed random seed."""
print(f"Loading olmOCR-mix-0225-benchmarkset dataset with {num_samples} samples and seed {seed}")
# Set random seed for reproducibility
random.seed(seed)
np.random.seed(seed)
# Import huggingface_hub utilities to list files
from huggingface_hub import list_repo_files, hf_hub_download
from huggingface_hub import hf_hub_download, list_repo_files
# List all PDF files in the repository
print("Listing PDF files in dataset...")
all_files = list_repo_files(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
repo_type="dataset"
)
all_files = list_repo_files(repo_id="allenai/olmOCR-mix-0225-benchmarkset", repo_type="dataset")
# Filter for PDF files in the pdfs directory
pdf_files = [f for f in all_files if f.startswith("pdfs/") and f.endswith(".pdf")]
if not pdf_files:
raise ValueError("No PDF files found in the dataset")
print(f"Found {len(pdf_files)} PDF files in dataset")
# Randomly sample num_samples PDFs
if len(pdf_files) > num_samples:
sampled_pdf_files = random.sample(pdf_files, num_samples)
else:
sampled_pdf_files = pdf_files
print(f"Warning: Only {len(pdf_files)} PDFs available, less than requested {num_samples}")
print(f"Sampled {len(sampled_pdf_files)} PDFs to download")
# Download only the sampled PDFs and process them
queries = []
with tempfile.TemporaryDirectory() as temp_dir:
@ -104,99 +100,82 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
try:
# Download individual PDF file
print(f"Downloading {pdf_file}...")
local_pdf_path = hf_hub_download(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
filename=pdf_file,
repo_type="dataset",
local_dir=temp_dir
)
local_pdf_path = hf_hub_download(repo_id="allenai/olmOCR-mix-0225-benchmarkset", filename=pdf_file, repo_type="dataset", local_dir=temp_dir)
# Build page query for page 1 of each PDF
query = await build_page_query(
local_pdf_path=local_pdf_path,
page=1,
target_longest_image_dim=1280,
image_rotation=0
)
query = await build_page_query(local_pdf_path=local_pdf_path, page=1, target_longest_image_dim=1280, image_rotation=0)
queries.append(query)
except Exception as e:
print(f"Error processing {os.path.basename(pdf_file)}: {e}")
continue
print(f"Successfully processed {len(queries)} PDFs")
return queries
def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
"""Process a single prompt with image and return comparison results."""
# Track if we found the first mismatch for max_prob_first_diff
found_first_mismatch = False
max_prob_first_diff = 0.0
# Extract messages from the sample (which is the output of build_page_query)
messages = sample['messages']
messages = sample["messages"]
# Extract the text prompt and image from the messages
user_message = messages[0]
text_prompt = None
image_base64 = None
for content in user_message['content']:
if content['type'] == 'text':
text_prompt = content['text']
elif content['type'] == 'image_url':
image_url = content['image_url']['url']
for content in user_message["content"]:
if content["type"] == "text":
text_prompt = content["text"]
elif content["type"] == "image_url":
image_url = content["image_url"]["url"]
# Extract base64 data after the comma
if ',' in image_url:
image_base64 = image_url.split(',')[1]
if "," in image_url:
image_base64 = image_url.split(",")[1]
else:
image_base64 = image_url
if text_prompt is None or image_base64 is None:
raise ValueError("Failed to extract text prompt or image from messages")
# Decode the base64 image to PIL Image
image_bytes = base64.b64decode(image_base64)
image = PIL.Image.open(BytesIO(image_bytes))
print(f"\n{'='*80}")
print(f"PROMPT: {text_prompt[:100]}..." if len(text_prompt) > 100 else f"PROMPT: {text_prompt}")
print(f"IMAGE: {image.size} {image.mode}")
# Generate with vLLM
print("\n=== vLLM Generation ===")
# For VLLM, use the messages just as comes out of build_page_query
outputs = llm.chat(messages, sampling_params)
output = outputs[0]
# Extract prompt and generated token IDs
prompt_token_ids = output.prompt_token_ids
generated_token_ids = output.outputs[0].token_ids
print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
print(
f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..."
if len(prompt_token_ids) > 10
else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}"
)
print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")
# Create input tensor from concatenated token IDs
# input_ids = torch.tensor([all_token_ids], device=device) # Not needed for HF VLM models
# HuggingFace forward pass
print("\n=== HuggingFace Forward Pass ===")
# Prepare inputs for HF model using the extracted image and text
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text_prompt}
]
}
]
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text_prompt}]}]
hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(
text=[hf_text_prompt],
images=[image],
return_tensors="pt"
).to(device)
inputs = processor(text=[hf_text_prompt], images=[image], return_tensors="pt").to(device)
print("INPUTS", inputs)
@ -204,32 +183,32 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
generated_ids_tensor = torch.tensor([generated_token_ids], device=device)
inputs["input_ids"] = torch.cat([inputs["input_ids"], generated_ids_tensor], dim=1)
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
with torch.no_grad():
outputs_hf = hf_model(**inputs)
logits = outputs_hf.logits[0] # [seq_len, vocab_size]
# Token-by-token comparison
print(f"\n{'Pos':>4} {'Token ID':>8} {'Token':>20} {'Type':>8} {'vLLM Prob':>12} {'HF Argmax':>10} {'HF Prob':>12} {'Match':>6} {'HF Token':>20}")
print("-" * 125)
# Get vLLM logprobs for generated tokens
vllm_logprobs = output.outputs[0].logprobs
# Track mismatch info
first_mismatch_idx = None
# Get all token IDs from the HF model's input
all_token_ids = inputs["input_ids"][0].tolist()
# Compare ALL tokens (prompt + generated)
for pos, token_id in enumerate(all_token_ids):
token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
token_str = processor.decode([token_id], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
# Determine if this is a prompt or generated token
is_prompt = pos < len(prompt_token_ids)
token_type = "prompt" if is_prompt else "gen"
# vLLM probability (only for generated tokens)
vllm_prob_str = "N/A"
vllm_prob = None
@ -242,17 +221,17 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Convert logprob to probability
vllm_prob = torch.exp(torch.tensor(token_logprobs[token_id].logprob)).item()
vllm_prob_str = f"{vllm_prob:12.6f}"
# HF prediction - only for generated tokens (skip prompt tokens entirely)
if pos > 0 and not is_prompt:
hf_logits_at_pos = logits[pos - 1]
hf_probs = torch.softmax(hf_logits_at_pos, dim=-1)
hf_argmax = torch.argmax(hf_logits_at_pos).item()
hf_prob = hf_probs[token_id].item()
# Check if predictions match
match = "" if token_id == hf_argmax else ""
# Track first mismatch and probability difference
if token_id != hf_argmax:
if first_mismatch_idx is None:
@ -261,21 +240,21 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
if vllm_prob is not None and not found_first_mismatch:
max_prob_first_diff = abs(vllm_prob - hf_prob)
found_first_mismatch = True
# Decode HF argmax token (only show if mismatch)
hf_token_str = ""
if token_id != hf_argmax:
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
else:
# Prompt tokens or first token - no HF comparison
print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {'':>10} {'':>12} {'':>6} {'':<20}")
# Summary
print(f"\n=== Summary ===")
print(f"Total tokens generated: {len(generated_token_ids)}")
# Calculate match rate
matches = 0
for i, token_id in enumerate(generated_token_ids):
@ -284,39 +263,33 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
hf_argmax = torch.argmax(hf_logits_at_pos).item()
if token_id == hf_argmax:
matches += 1
match_rate = matches / len(generated_token_ids) * 100 if generated_token_ids else 0
print(f"Token match rate: {matches}/{len(generated_token_ids)} ({match_rate:.1f}%)")
# Report first mismatch index
if first_mismatch_idx is not None:
print(f"First mismatch at generation index: {first_mismatch_idx}")
print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
else:
print("No mismatches found in generated tokens")
return {
'first_mismatch_idx': first_mismatch_idx,
'max_prob_first_diff': max_prob_first_diff,
'match_rate': match_rate,
'num_generated': len(generated_token_ids)
"first_mismatch_idx": first_mismatch_idx,
"max_prob_first_diff": max_prob_first_diff,
"match_rate": match_rate,
"num_generated": len(generated_token_ids),
}
async def async_main():
parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20,
help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0,
help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100,
help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20,
help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42,
help="Random seed for prompt selection")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20, help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100, help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20, help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42, help="Random seed for prompt selection")
args = parser.parse_args()
print(f"Model: {args.model}")
@ -330,72 +303,63 @@ async def async_main():
# Load prompts and images
samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
# Load HuggingFace model and processor first
print("\n=== Loading HuggingFace Model ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
hf_model = AutoModelForVision2Seq.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map="auto"
)
hf_model = AutoModelForVision2Seq.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
hf_model.eval()
# Create vLLM engine once
print("\n=== Creating vLLM Engine ===")
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
logprobs=1 # Get top-1 logprobs
)
sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens, logprobs=1) # Get top-1 logprobs
# Process samples until finding significant mismatch
print("\n=== Processing Samples ===")
# Initialize statistics tracking
all_results = []
for i, sample in enumerate(samples):
print(f"\n\n{'#'*80}")
print(f"### Processing sample {i+1}/{len(samples)}")
print(f"{'#'*80}")
# Process single sample
result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
all_results.append(result)
# Check if we found significant mismatch
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
if result["first_mismatch_idx"] is not None and result["max_prob_first_diff"] > args.prob_threshold:
print(f"\n\n{'*'*80}")
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
print(f"{'*'*80}")
# Report aggregated statistics
print(f"\n\n{'='*80}")
print("=== AGGREGATED STATISTICS ===")
print(f"{'='*80}")
total_samples = len(all_results)
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
total_tokens_generated = sum(r['num_generated'] for r in all_results)
samples_with_mismatches = sum(1 for r in all_results if r["first_mismatch_idx"] is not None)
total_tokens_generated = sum(r["num_generated"] for r in all_results)
print(f"Total samples processed: {total_samples}")
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
print(f"Total tokens generated: {total_tokens_generated}")
if samples_with_mismatches > 0:
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
avg_match_rate = sum(r["match_rate"] for r in all_results) / total_samples
max_prob_diffs = [r["max_prob_first_diff"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
max_prob_diff_overall = max(max_prob_diffs)
first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
first_mismatch_positions = [r["first_mismatch_idx"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
print(f"\nMismatch Statistics:")
print(f" Average token match rate: {avg_match_rate:.1f}%")
print(f" Average first mismatch position: {avg_first_mismatch_pos:.1f}")
@ -403,7 +367,7 @@ async def async_main():
print(f" Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
else:
print("\nNo mismatches found in any samples!")
print(f"\n{'='*80}")
@ -412,4 +376,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -7,7 +7,7 @@ Compresses OlmOCR checkpoints using FP8 quantization:
Usage:
python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N] [--calibration-pdfs PDF1+PDF2+...]
source_path: Path to checkpoint (local or S3)
destination_path: Where to save compressed checkpoint (local or S3)
recipe_path: Path to quantization config YAML file
@ -26,50 +26,54 @@ import shutil
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple, Union, List
from typing import List, Optional, Tuple, Union
import boto3
import torch
from datasets import Dataset
from llmcompressor import oneshot
from PIL import Image
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from olmocr.s3_utils import parse_s3_path
from olmocr.pipeline import build_page_query
from olmocr.s3_utils import parse_s3_path
s3_client = boto3.client("s3")
def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
"""Get calibration PDFs from provided paths.
Args:
num_samples: Number of samples to use
pdf_paths: List of local PDF paths
Returns:
List of valid PDF paths
"""
print(f"Using {len(pdf_paths)} provided calibration PDFs")
# If more PDFs provided than needed, randomly sample
if len(pdf_paths) > num_samples:
pdf_paths = random.sample(pdf_paths, num_samples)
print(f"Randomly sampled {num_samples} PDFs from provided paths")
# Verify all PDFs exist
valid_paths = []
for path in pdf_paths:
if os.path.exists(path) and path.endswith('.pdf'):
if os.path.exists(path) and path.endswith(".pdf"):
valid_paths.append(path)
else:
print(f" Warning: Skipping invalid path: {path}")
if not valid_paths:
raise ValueError("No valid PDF paths found in the provided list")
print(f"Using {len(valid_paths)} valid calibration PDFs")
return valid_paths
@ -77,14 +81,14 @@ def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Dataset:
"""Prepare calibration dataset from PDFs using build_page_query."""
dataset_items = []
for pdf_path in pdf_paths:
# Get first page of each PDF (page 0)
query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
# Extract the messages
messages = query["messages"]
# Extract images from the message content
images = []
for message in messages:
@ -99,12 +103,10 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
image_bytes = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_bytes))
images.append(image)
# Apply chat template to get the text
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Process with tokenizer
inputs = processor(
text=[text],
@ -113,31 +115,32 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
max_length=8192,
truncation=True,
)
dataset_items.append(inputs)
# Convert list of dicts to HuggingFace Dataset
if dataset_items:
# Create dataset in batches to avoid overflow
batch_size = 50 # Process in smaller batches
all_datasets = []
for i in range(0, len(dataset_items), batch_size):
batch = dataset_items[i:i + batch_size]
batch = dataset_items[i : i + batch_size]
# Flatten the batch into a dict of lists
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [item[key] for item in batch]
# Create dataset for this batch
batch_dataset = Dataset.from_dict(batch_dict)
all_datasets.append(batch_dataset)
# Concatenate all batch datasets
if len(all_datasets) == 1:
return all_datasets[0]
else:
from datasets import concatenate_datasets
return concatenate_datasets(all_datasets)
else:
return Dataset.from_dict({})
@ -151,21 +154,21 @@ def is_s3_path(path: str) -> bool:
def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
"""Download all files from S3 prefix to local directory."""
os.makedirs(local_dir, exist_ok=True)
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
print(f"Downloading checkpoint from s3://{bucket}/{prefix} to {local_dir}...")
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith("/"):
continue
rel_path = os.path.relpath(key, prefix)
local_path = os.path.join(local_dir, rel_path)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
s3_client.download_file(bucket, key, local_path)
print(f" Downloaded {rel_path}")
@ -174,18 +177,20 @@ def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
"""Upload all files from local directory to S3."""
print(f"Uploading compressed checkpoint from {local_dir} to s3://{bucket}/{prefix}...")
for root, _, files in os.walk(local_dir):
for file in files:
local_path = os.path.join(root, file)
rel_path = os.path.relpath(local_path, local_dir)
s3_key = os.path.join(prefix, rel_path)
s3_client.upload_file(local_path, bucket, s3_key)
print(f" Uploaded {rel_path}")
def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
def load_model_and_tokenizer(
source_path: str,
) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
"""Load model and tokenizer from source path (local or S3)."""
if is_s3_path(source_path):
# Download from S3 to temporary directory
@ -196,53 +201,37 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
else:
model_path = source_path
temp_dir = None
# Read config to determine model architecture
config_path = os.path.join(model_path, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
# Get model name from config
model_name = config.get("name_or_path", "")
print(f"Loading model from {model_path}...")
# Load appropriate model class based on name
if "Qwen2.5-VL" in model_name:
print("Detected Qwen2.5-VL model")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
elif "Qwen2-VL" in model_name:
print("Detected Qwen2-VL model")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
# Default to checking architectures list
architectures = config.get("architectures", [])
if "Qwen2_5_VLForConditionalGeneration" in architectures:
print("Detected Qwen2.5-VL model from architectures")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
print("Detected Qwen2-VL model from architectures")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
print(f"Loading tokenizer from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer, temp_dir
@ -250,10 +239,10 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
"""Copy additional config files that are needed but not saved by save_pretrained."""
# List of additional files to copy if they exist
additional_files = ["preprocessor_config.json", "chat_template.json"]
# Determine the actual source path (could be temp dir if downloaded from S3)
actual_source = temp_source_dir if temp_source_dir else source_path
for filename in additional_files:
source_file = os.path.join(actual_source, filename)
if os.path.exists(source_file):
@ -268,55 +257,50 @@ def data_collator(batch):
return {key: torch.tensor(value) for key, value in batch[0].items()}
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None) -> None:
def compress_checkpoint(
source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None
) -> None:
"""Compress OlmOCR checkpoint using FP8 quantization."""
# Load model and tokenizer
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
try:
# Print all model tensor names
print("\n=== Model Tensor Names ===")
for name, param in model.named_parameters():
print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
print("=========================\n")
# Prepare calibration dataset if requested
dataset = None
if num_calibration_samples > 0:
if not calibration_pdfs:
raise ValueError("Calibration PDFs must be provided when num_calibration_samples > 0. Use --calibration-pdfs argument.")
print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
# Load processor for the model
processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
# Get calibration PDFs from provided paths
pdf_paths = get_calibration_pdfs(num_calibration_samples, calibration_pdfs)
# Prepare dataset
dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
print(f"✓ Prepared {len(dataset)} calibration samples")
# Apply quantization using provided recipe
print(f"\nApplying quantization using recipe: {recipe_path}")
if dataset:
oneshot(
model=model,
recipe=recipe_path,
dataset=dataset,
max_seq_length=8192,
num_calibration_samples=len(dataset),
data_collator=data_collator
)
oneshot(model=model, recipe=recipe_path, dataset=dataset, max_seq_length=8192, num_calibration_samples=len(dataset), data_collator=data_collator)
else:
oneshot(model=model, recipe=recipe_path)
print("✓ Quantization completed successfully")
# Save the compressed model
if is_s3_path(dest_path):
# Save to temporary directory first, then upload to S3
@ -324,10 +308,10 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
print(f"\nSaving compressed model to temporary directory...")
model.save_pretrained(temp_dest_dir)
tokenizer.save_pretrained(temp_dest_dir)
# Copy additional files
copy_additional_files(source_path, temp_dest_dir, temp_source_dir)
# Upload to S3
bucket, prefix = parse_s3_path(dest_path)
upload_local_to_s3(temp_dest_dir, bucket, prefix)
@ -337,18 +321,18 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
os.makedirs(dest_path, exist_ok=True)
model.save_pretrained(dest_path)
tokenizer.save_pretrained(dest_path)
# Copy additional files
copy_additional_files(source_path, dest_path, temp_source_dir)
print(f"\n✓ Successfully compressed checkpoint and saved to {dest_path}")
finally:
# Clean up temporary source directory if needed
if temp_source_dir:
print(f"Cleaning up temporary directory {temp_source_dir}...")
shutil.rmtree(temp_source_dir)
# Free up GPU memory
del model
torch.cuda.empty_cache()
@ -377,48 +361,49 @@ Examples:
# Using recursive glob pattern
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
"""
""",
)
parser.add_argument("source", help="Source checkpoint path (local or S3)")
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
parser.add_argument("--num-calibration-samples", type=int, default=512,
help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument("--calibration-pdfs", type=str, default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
parser.add_argument("--num-calibration-samples", type=int, default=512, help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument(
"--calibration-pdfs",
type=str,
default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.",
)
args = parser.parse_args()
# Parse calibration PDFs if provided
calibration_pdfs = None
if args.calibration_pdfs:
# Use pathlib for better glob handling
pattern = args.calibration_pdfs
# Handle both absolute and relative paths with recursive glob
if '**' in pattern:
if "**" in pattern:
# For recursive patterns, we need to handle them specially
if pattern.startswith('/'):
if pattern.startswith("/"):
# Absolute path with **
parts = pattern.split('**')
parts = pattern.split("**")
base_path = Path(parts[0])
glob_pattern = '**' + parts[1] if len(parts) > 1 else '**/*.pdf'
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip('/')))
glob_pattern = "**" + parts[1] if len(parts) > 1 else "**/*.pdf"
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip("/")))
else:
# Relative path with **
calibration_pdfs = list(Path('.').glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == '.pdf']
calibration_pdfs = list(Path(".").glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == ".pdf"]
else:
# Use standard glob for non-recursive patterns
calibration_pdfs = glob.glob(pattern)
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith('.pdf')]
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith(".pdf")]
print(f"Found {len(calibration_pdfs)} PDF files matching pattern: {args.calibration_pdfs}")
compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)
if __name__ == "__main__":
exit(main())

View File

@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from omegaconf import DictConfig, OmegaConf
from omegaconf import OmegaConf
@dataclass
@ -200,7 +200,7 @@ class TrainingConfig:
# Performance
dataloader_drop_last: bool = True
dataloader_num_workers: int = 16
# Data collator settings
collator_max_token_len: Optional[int] = None
remove_unused_columns: bool = False # Important for custom datasets
@ -314,10 +314,10 @@ class Config:
FrontMatterOutputFormat,
FrontMatterParser,
InstructUserMessages,
JSONOutputFormat,
LatexBracketNormalizer,
NewYamlFinetuningPromptWithAnchoring,
NewYamlFinetuningPromptWithNoAnchoring,
JSONOutputFormat,
PDFRenderer,
RandomTokenFlipper,
StaticLengthDocumentAnchoring,
@ -382,18 +382,18 @@ class Config:
if processor is None:
raise ValueError("Processor must be provided for RandomTokenFlipper step (to get valid tokens)")
tokenizer = processor.tokenizer
# Get all special token IDs to exclude
special_token_ids = set()
for token in tokenizer.all_special_tokens:
special_token_ids.add(tokenizer.convert_tokens_to_ids(token))
# Get all token IDs that are not special tokens
valid_token_ids = []
for token_id in range(len(tokenizer)):
if token_id not in special_token_ids:
valid_token_ids.append(token_id)
steps.append(
RandomTokenFlipper(
valid_token_ids=valid_token_ids,

View File

@ -1,4 +1,5 @@
import base64
import json
import logging
import re
from abc import ABC, abstractmethod
@ -8,8 +9,7 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
import json
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias
import numpy as np
import yaml
@ -31,10 +31,10 @@ logger = logging.getLogger(__name__)
def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optional[Tuple[Path, str]]]:
"""Validate a single markdown-PDF pair.
Args:
md_path: Path to the markdown file
Returns:
Tuple of (valid_sample, invalid_pdf_info)
- valid_sample: Dict with markdown_path and pdf_path if valid, None otherwise
@ -42,31 +42,32 @@ def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optiona
"""
# Look for PDF with same stem (filename without extension)
pdf_path = md_path.with_suffix(".pdf")
if pdf_path.exists() or pdf_path.is_symlink():
# Resolve symlink if it is one
if pdf_path.is_symlink():
pdf_path = pdf_path.resolve()
# Verify the resolved path exists
if pdf_path.exists():
# Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
try:
reader = PdfReader(str(pdf_path))
num_pages = len(reader.pages)
if num_pages != 1:
return None, (pdf_path, f"Expected 1 page, found {num_pages}")
# Test that document anchoring works
from olmocr.prompts.anchor import get_anchor_text
get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
return {"markdown_path": md_path, "pdf_path": pdf_path}, None
except Exception as e:
return None, (pdf_path, f"Failed to load: {str(e)}")
return None, None
@ -104,29 +105,29 @@ class BaseMarkdownPDFDataset(Dataset):
invalid_pdfs = []
logger.info(f"Validating {len(md_files)} markdown-PDF pairs using ProcessPoolExecutor...")
# Use ProcessPoolExecutor for parallel validation
with ProcessPoolExecutor(max_workers=8) as executor:
# Submit all validation tasks
future_to_md = {executor.submit(validate_pdf_pair, md_path): md_path for md_path in md_files}
# Process results as they complete
with tqdm(total=len(md_files), desc="Validating PDFs") as pbar:
for future in as_completed(future_to_md):
md_path = future_to_md[future]
try:
valid_sample, invalid_pdf_info = future.result()
if valid_sample:
self.samples.append(valid_sample)
valid_count += 1
elif invalid_pdf_info:
invalid_pdfs.append(invalid_pdf_info)
except Exception as e:
logger.error(f"Error processing {md_path}: {str(e)}")
invalid_pdfs.append((md_path.with_suffix(".pdf"), f"Processing error: {str(e)}"))
pbar.update(1)
logger.info(f"Found {valid_count} valid markdown-PDF pairs")
@ -205,11 +206,11 @@ class FrontMatterParser(PipelineStep):
value = front_matter_dict[field_name]
# Handle type conversions
if field_type == int and isinstance(value, str):
if field_type is int and isinstance(value, str):
kwargs[field_name] = int(value)
elif field_type == bool and isinstance(value, str):
elif field_type is bool and isinstance(value, str):
kwargs[field_name] = value.lower() == "true"
elif field_type == Optional[str]:
elif field_type is Optional[str]:
kwargs[field_name] = value if value else None
else:
kwargs[field_name] = value
@ -288,7 +289,7 @@ class FinetuningPrompt(PipelineStep):
def __call__(self, sample: Sample) -> Sample:
sample["instruction_prompt"] = build_finetuning_prompt(sample["anchor_text"])
return sample
@dataclass(frozen=True, slots=True)
class NewYamlFinetuningPromptWithAnchoring(PipelineStep):
@ -323,7 +324,7 @@ class FrontMatterOutputFormat(PipelineStep):
def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse
sample["response"] = (
f"""---
@ -346,58 +347,63 @@ class JSONOutputFormat(PipelineStep):
def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse
sample["response"] = json.dumps({
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text
}, ensure_ascii=False)
sample["response"] = json.dumps(
{
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text,
},
ensure_ascii=False,
)
return sample
@dataclass(frozen=True, slots=True)
class LatexBracketNormalizer(PipelineStep):
"""Normalizes LaTeX brackets in natural text field."""
def __call__(self, sample: Sample) -> Sample:
"""Normalize LaTeX brackets in the natural text field."""
# Get the page_data object
if "page_data" not in sample:
return sample
page_data = sample["page_data"]
if not hasattr(page_data, "natural_text") or not page_data.natural_text:
return sample
text = page_data.natural_text
# Define patterns for LaTeX normalization
# Order matters: process display math first, then inline
patterns = [
(r"\$\$(.+?)\$\$", r"\[\1\]"), # $$...$$ to \[...\]
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
]
# Apply replacements
for pattern, replacement in patterns:
text = re.sub(pattern, replacement, text, flags=re.DOTALL)
# Update the page_data with normalized text
# Since PageResponse is frozen, we need to create a new instance
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=page_data.is_rotation_valid,
rotation_correction=page_data.rotation_correction,
is_table=page_data.is_table,
is_diagram=page_data.is_diagram,
natural_text=text
natural_text=text,
)
sample["page_data"] = new_page_data
return sample
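A quick illustration (not part of the diff) of how the two patterns above rewrite delimiters, display math first and then inline math; the input string is made up:

import re

text = "Euler: $$e^{i\\pi} + 1 = 0$$ and inline $x + y$."
patterns = [
    (r"\$\$(.+?)\$\$", r"\[\1\]"),  # $$...$$ -> \[...\]
    (r"\$(.+?)\$", r"\(\1\)"),  # $...$ -> \(...\)
]
for pattern, replacement in patterns:
    text = re.sub(pattern, replacement, text, flags=re.DOTALL)
print(text)  # Euler: \[e^{i\pi} + 1 = 0\] and inline \(x + y\).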
@ -493,26 +499,26 @@ class Tokenizer(PipelineStep):
@dataclass(frozen=True, slots=True)
class RandomTokenFlipper(PipelineStep):
"""Randomly flips tokens in the output (non-masked) portion and masks their labels."""
valid_token_ids: List[int] # List of valid token IDs to substitute with
token_flip_rate: float = 1e-4
masking_index: int = -100
def __call__(self, sample: Sample) -> Sample:
"""Randomly flip tokens in the non-masked portion of labels."""
if "labels" not in sample or "input_ids" not in sample:
return sample
# Work with copies to avoid modifying original arrays
labels = sample["labels"].copy()
input_ids = sample["input_ids"].copy()
# Find indices where labels are not masked (i.e., output tokens)
non_masked_indices = np.where(labels != self.masking_index)[0]
if len(non_masked_indices) == 0:
return sample
# For each non-masked token, independently decide whether to flip
for idx in non_masked_indices:
if np.random.random() < self.token_flip_rate:
@ -520,11 +526,11 @@ class RandomTokenFlipper(PipelineStep):
random_token = np.random.choice(self.valid_token_ids)
input_ids[idx] = random_token
labels[idx] = self.masking_index
# Update sample with modified arrays
sample["input_ids"] = input_ids
sample["labels"] = labels
return sample
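To make the flip-and-mask step above concrete, a tiny hypothetical walkthrough (token IDs invented for illustration):

import numpy as np

masking_index = -100
input_ids = np.array([1, 2, 3, 40, 41, 42])  # prompt tokens then output tokens
labels = np.array([-100, -100, -100, 40, 41, 42])  # prompt positions already masked

flip_pos = 4  # suppose the per-token random draw fires here
input_ids[flip_pos] = 999  # replaced by a random valid (non-special) token id
labels[flip_pos] = masking_index  # masked so the loss never trains on the corrupted token
# The model now sees a noisy input at position 4 but is not asked to predict it.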
@ -590,26 +596,27 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Import config module
from olmocr.train.config import Config
# Load configuration
print(f"\n=== Loading configuration from {args.config} ===")
config = Config.from_yaml(args.config)
# Validate configuration
try:
config.validate()
except ValueError as e:
print(f"Configuration validation failed: {e}")
exit(1)
# Load processor for tokenization
print(f"\nLoading processor: {config.model.name}")
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained(config.model.name)
# Select dataset based on type
if args.dataset_type == "train":
dataset_configs = config.dataset.train
@ -617,19 +624,19 @@ if __name__ == "__main__":
else:
dataset_configs = config.dataset.eval
dataset_name = "eval"
if args.dataset_index >= len(dataset_configs):
print(f"Error: Dataset index {args.dataset_index} out of range. Only {len(dataset_configs)} {dataset_name} datasets available.")
exit(1)
dataset_cfg = dataset_configs[args.dataset_index]
root_dir = dataset_cfg["root_dir"]
pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)
print(f"\n=== Testing {dataset_name} dataset {args.dataset_index} ===")
print(f"Root directory: {root_dir}")
print(f"Pipeline steps: {[step.__class__.__name__ for step in pipeline_steps]}")
# Create dataset
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
@ -641,7 +648,7 @@ if __name__ == "__main__":
for i in range(min(5, len(dataset))):
sample = dataset.samples[i]
print(f" {i}: MD: {sample['markdown_path'].name}, PDF: {sample['pdf_path'].name}")
# Check if sample index is valid
if args.sample_index >= len(dataset):
print(f"\nError: Sample index {args.sample_index} out of range. Only {len(dataset)} samples available.")
@ -650,39 +657,39 @@ if __name__ == "__main__":
# Get the requested sample
print(f"\n=== Displaying sample {args.sample_index} ===")
sample = dataset[args.sample_index]
# Display sample information based on pipeline output
print("\nSample keys:", list(sample.keys()))
# If it's raw data (no tokenization)
if 'markdown_path' in sample:
if "markdown_path" in sample:
print(f"\nMarkdown file: {sample['markdown_path'].name}")
if 'pdf_path' in sample:
if "pdf_path" in sample:
print(f"PDF file: {sample['pdf_path'].name}")
if 'image' in sample and hasattr(sample['image'], 'size'):
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")
if 'page_data' in sample:
if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if 'messages' in sample:
if "messages" in sample:
print(f"\n=== Messages ===")
for i, msg in enumerate(sample['messages']):
for i, msg in enumerate(sample["messages"]):
print(f"\nMessage {i}:")
print(f" Role: {msg['role']}")
print(f" Content preview: {str(msg['content'])[:200]}...")
# If it's tokenized data
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n=== Tokenized Output ===")
print(f" Keys: {list(sample.keys())}")
print(f" Input IDs shape: {sample['input_ids'].shape}")
print(f" Labels shape: {sample['labels'].shape}")
print(f" Attention mask shape: {sample['attention_mask'].shape}")
if "pixel_values" in sample:
print(f" Pixel values shape: {sample['pixel_values'].shape}")
if "image_grid_thw" in sample:
print(f" Image grid THW: {sample['image_grid_thw']}")
# Show label masking
print(f"\nLabel masking analysis:")
labels = sample["labels"]
@ -691,29 +698,29 @@ if __name__ == "__main__":
print(f" Total tokens: {total_count}")
print(f" Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
print(f" Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")
# Find the transition point
transition_idx = None
for i in range(len(labels) - 1):
if labels[i] == -100 and labels[i + 1] != -100:
transition_idx = i + 1
break
if transition_idx:
print(f" Transition from masked to unmasked at position: {transition_idx}")
# Print all tokens
input_ids = sample["input_ids"]
print(f"\nAll tokens ({len(input_ids)} total):")
print("Format: [index] Token (repr) | Label | Token ID")
print("-" * 80)
for i in range(len(input_ids)):
token = processor.tokenizer.decode([input_ids[i]])
token_repr = repr(token)
label = labels[i] if i < len(labels) else "N/A"
token_id = input_ids[i]
# Mark special positions
marker = ""
if transition_idx and i == transition_idx:
@ -722,64 +729,65 @@ if __name__ == "__main__":
marker = " <-- START"
elif label != -100 and i > 0 and labels[i - 1] == -100:
marker = " <-- response begins"
print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
# Calculate and show token statistics after the table
print(f"\nToken statistics:")
# Count consecutive high-value tokens that represent the image
# Qwen uses tokens like 151859, 151860, etc. for image patches
image_token_threshold = 151000 # Typical threshold for Qwen image tokens
image_token_count = np.sum(input_ids > image_token_threshold)
# Calculate prompt tokens (everything masked)
prompt_token_count = masked_count
# Calculate output tokens (everything not masked)
output_token_count = total_count - masked_count
# Calculate non-image prompt tokens
non_image_prompt_tokens = prompt_token_count - image_token_count
print(f" Image tokens: {image_token_count}")
print(f" Prompt tokens (total): {prompt_token_count}")
print(f" Prompt tokens (non-image): {non_image_prompt_tokens}")
print(f" Output tokens: {output_token_count}")
print(f" Total sequence length: {total_count}")
# Analyze token length distribution across entire dataset
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n\n=== Analyzing token length distribution across entire dataset ===")
print(f"Processing {len(dataset)} samples...")
# Function to process a single sample
def process_sample(idx):
try:
current_sample = dataset[idx]
if 'labels' in current_sample:
if "labels" in current_sample:
# Count total sequence length (all tokens, prompt + completion)
labels = current_sample['labels']
labels = current_sample["labels"]
total_length = len(labels)
return (idx, total_length, None)
return (idx, None, "No labels in sample")
except Exception as e:
return (idx, None, str(e))
# Process samples in parallel with progress bar
sequence_lengths = []
max_sequence_length = 0
max_sequence_sample_idx = 0
errors = []
# Determine number of workers (use fewer workers to avoid memory issues)
import multiprocessing
num_workers = min(multiprocessing.cpu_count() // 2, 8)
with ProcessPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = {executor.submit(process_sample, idx): idx for idx in range(len(dataset))}
# Process results with progress bar
with tqdm(total=len(dataset), desc="Analyzing samples") as pbar:
for future in as_completed(futures):
@ -796,16 +804,16 @@ if __name__ == "__main__":
except Exception as e:
errors.append((idx, f"Future error: {e}"))
pbar.update(1)
if errors:
print(f"\nEncountered {len(errors)} errors during processing")
if len(errors) <= 5:
for idx, error in errors:
print(f" Sample {idx}: {error}")
if sequence_lengths:
sequence_lengths = np.array(sequence_lengths)
print(f"\nTotal sequence length statistics (prompt + completion):")
print(f" Total samples analyzed: {len(sequence_lengths)}")
print(f" Max sequence length: {max_sequence_length} tokens (sample index: {max_sequence_sample_idx})")
@ -813,37 +821,37 @@ if __name__ == "__main__":
print(f" Mean sequence length: {np.mean(sequence_lengths):.1f} tokens")
print(f" Median sequence length: {np.median(sequence_lengths):.1f} tokens")
print(f" Std dev: {np.std(sequence_lengths):.1f} tokens")
# Create histogram with 100-token buckets
print(f"\nSequence length histogram (100-token buckets):")
# Define buckets
bucket_size = 100
max_bucket = ((max_sequence_length // bucket_size) + 1) * bucket_size
buckets = list(range(0, max_bucket + bucket_size, bucket_size))
# Count samples in each bucket
hist, _ = np.histogram(sequence_lengths, bins=buckets)
# Find max count for scaling
max_count = max(hist)
bar_width = 50 # Width of histogram bars
print(f"\n{'Range':>15} | {'Count':>6} | Distribution")
print("-" * 80)
for i in range(len(hist)):
start = buckets[i]
end = buckets[i + 1] - 1
count = hist[i]
# Create bar
if max_count > 0:
bar_length = int((count / max_count) * bar_width)
bar = "" * bar_length
else:
bar = ""
range_str = f"{start:>5}-{end:>5}"
print(f"{range_str:>15} | {count:>6} | {bar}")

View File

@ -9,22 +9,22 @@ Supports model souping (averaging weights of multiple checkpoints).
Usage:
python prepare_olmocr_checkpoint.py <source_path> <destination_path>
source_path: Path to checkpoint (local or S3)
destination_path: Where to save prepared checkpoint (local or S3)
For souping multiple checkpoints:
python prepare_olmocr_checkpoint.py <source1> <source2> ... <destination>
This will average the weights of all sources and prepare the souped checkpoint.
Examples:
# Single local to local
python prepare_olmocr_checkpoint.py /path/to/checkpoint /path/to/output
# Souping multiple S3 to S3
python prepare_olmocr_checkpoint.py s3://bucket/ckpt1 s3://bucket/ckpt2 s3://bucket/souped
# Mixed souping
python prepare_olmocr_checkpoint.py s3://bucket/ckpt1 /local/ckpt2 s3://bucket/souped
"""
@ -38,9 +38,9 @@ import tempfile
import boto3
import requests
import torch
from smart_open import smart_open
from tqdm import tqdm
import torch
try:
from safetensors.torch import load_file, save_file
@ -50,32 +50,16 @@ except ImportError:
from olmocr.s3_utils import parse_s3_path
# Hugging Face model IDs for tokenizer files
HF_MODEL_IDS = {
"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct",
"Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"
}
HF_MODEL_IDS = {"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct", "Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"}
# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
"chat_template.json",
"merges.txt",
"preprocessor_config.json",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json"
]
TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"]
# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
# Files to exclude from copying (training-related files)
EXCLUDED_FILES = {
"optimizer.pt",
"scheduler.pt",
"rng_state.pth",
"trainer_state.json",
"training_args.bin"
}
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}
s3_client = boto3.client("s3")
@ -89,34 +73,34 @@ def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str)
"""Download a file from Hugging Face model repository."""
url = f"{hf_base_url}/{filename}"
local_path = os.path.join(destination_dir, filename)
print(f"Downloading {filename} from Hugging Face...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Downloaded {filename}")
def detect_checkpoint_architecture(config_path: str) -> str:
"""Detect and validate the checkpoint architecture."""
print(f"Detecting checkpoint architecture from {config_path}...")
with smart_open(config_path, "r") as f:
config_data = json.load(f)
architectures = config_data.get("architectures", [])
# Find the supported architecture
detected_architecture = None
for arch in architectures:
if arch in SUPPORTED_ARCHITECTURES:
detected_architecture = arch
break
if not detected_architecture:
# Try to detect from model name
model_name = config_data.get("name_or_path", "")
@ -125,11 +109,8 @@ def detect_checkpoint_architecture(config_path: str) -> str:
elif "Qwen2-VL" in model_name:
detected_architecture = "Qwen2VLForConditionalGeneration"
else:
raise ValueError(
f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} "
f"but found: {architectures}"
)
raise ValueError(f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} " f"but found: {architectures}")
print(f"✓ Detected architecture: {detected_architecture}")
return detected_architecture
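The check above keys off the architectures list in the checkpoint's config.json, falling back to name_or_path when that list has no supported entry. A hedged, minimal example of the shape of config it accepts (field values are illustrative):

# Illustrative config.json content, expressed as a Python dict; only the
# "architectures" field (and, as a fallback, "name_or_path") is consulted.
example_config = {
    "architectures": ["Qwen2_5_VLForConditionalGeneration"],
    "name_or_path": "Qwen/Qwen2.5-VL-7B-Instruct",
}
assert any(arch in SUPPORTED_ARCHITECTURES for arch in example_config["architectures"])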
@ -137,7 +118,7 @@ def detect_checkpoint_architecture(config_path: str) -> str:
def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
"""Copy files from local directory to local directory."""
os.makedirs(dest_dir, exist_ok=True)
# Get list of files to copy
files_to_copy = []
for root, _, files in os.walk(source_dir):
@ -148,9 +129,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
src_path = os.path.join(root, file)
rel_path = os.path.relpath(src_path, source_dir)
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))
print(f"Copying {len(files_to_copy)} files from {source_dir} to {dest_dir}...")
for src_path, dst_path in tqdm(files_to_copy, desc="Copying files"):
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
shutil.copy2(src_path, dst_path)
@ -170,35 +151,32 @@ def upload_file_to_s3(local_path: str, bucket: str, key: str) -> None:
def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> None:
"""Copy files from S3 to local directory."""
os.makedirs(dest_dir, exist_ok=True)
# List all objects in source
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)
download_tasks = []
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith("/"):
continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue
rel_path = os.path.relpath(key, source_prefix)
local_path = os.path.join(dest_dir, rel_path)
download_tasks.append((source_bucket, key, local_path))
print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(download_file_from_s3, bucket, key, local_path)
for bucket, key, local_path in download_tasks
]
futures = [executor.submit(download_file_from_s3, bucket, key, local_path) for bucket, key, local_path in download_tasks]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
future.result()
@ -216,15 +194,12 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
rel_path = os.path.relpath(local_path, source_dir)
s3_key = os.path.join(dest_prefix, rel_path)
upload_tasks.append((local_path, dest_bucket, s3_key))
print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
future.result()
@ -234,26 +209,26 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
# List all objects in source
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)
copy_tasks = []
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith("/"):
continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue
rel_path = os.path.relpath(key, source_prefix)
dest_key = os.path.join(dest_prefix, rel_path)
copy_source = {"Bucket": source_bucket, "Key": key}
copy_tasks.append((copy_source, dest_bucket, dest_key))
print(f"Copying {len(copy_tasks)} files from s3://{source_bucket}/{source_prefix} to s3://{dest_bucket}/{dest_prefix}...")
for copy_source, bucket, key in tqdm(copy_tasks, desc="Copying"):
s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=key)
@ -350,10 +325,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)
if file_path.endswith('.safetensors'):
sum_state = load_file(file_path, device='cpu')
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
for other_path in all_paths[1:]:
other_state = load_file(other_path, device='cpu')
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -363,10 +338,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
for k in sum_state:
sum_state[k] /= n
save_file(sum_state, souped_path)
elif file_path.endswith('.bin'):
sum_state = torch.load(file_path, map_location='cpu')
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location='cpu')
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -390,19 +365,16 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
# Download tokenizer files from Hugging Face
print("\nDownloading tokenizer files from Hugging Face...")
if is_s3_path(dest_path):
# Download to temp directory first, then upload to S3
with tempfile.TemporaryDirectory() as temp_dir:
# Download files
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()
# Upload to S3
dest_bucket, dest_prefix = parse_s3_path(dest_path)
upload_tasks = []
@ -410,25 +382,19 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
local_path = os.path.join(temp_dir, filename)
s3_key = os.path.join(dest_prefix, filename)
upload_tasks.append((local_path, dest_bucket, s3_key))
print("Uploading tokenizer files to S3...")
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in concurrent.futures.as_completed(futures):
future.result()
else:
# Download directly to destination
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, dest_path, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, dest_path, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()
print(f"\n✓ Successfully prepared checkpoint at {dest_path}")
@ -436,26 +402,26 @@ def main():
parser = argparse.ArgumentParser(
description="Prepare OlmOCR checkpoint for deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__.split("Usage:")[1] # Use the docstring for epilog
epilog=__doc__.split("Usage:")[1], # Use the docstring for epilog
)
parser.add_argument("paths", nargs='+', help="One or more source paths followed by destination path (local or S3)")
parser.add_argument("paths", nargs="+", help="One or more source paths followed by destination path (local or S3)")
args = parser.parse_args()
if len(args.paths) < 2:
parser.error("At least one source and one destination required")
sources = args.paths[:-1]
destination = args.paths[-1]
try:
prepare_checkpoints(sources, destination)
except Exception as e:
print(f"\n❌ Error: {e}")
return 1
return 0
if __name__ == "__main__":
exit(main())
exit(main())

View File

@ -5,6 +5,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
from typing import Optional
import numpy as np
import torch
@ -18,7 +19,6 @@ from transformers import (
TrainingArguments,
)
from typing import Optional
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
@ -47,13 +47,13 @@ class QwenDataCollator:
input_ids = torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]
attention_mask = torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
labels = torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]
# Trim to max_token_len if specified
if self.max_token_len is not None:
input_ids = input_ids[:self.max_token_len]
attention_mask = attention_mask[:self.max_token_len]
labels = labels[:self.max_token_len]
input_ids = input_ids[: self.max_token_len]
attention_mask = attention_mask[: self.max_token_len]
labels = labels[: self.max_token_len]
batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
batch["labels"].append(labels)
@ -103,7 +103,7 @@ def main():
if config.project_name:
os.environ["WANDB_PROJECT"] = config.project_name
logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")
# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
@ -209,7 +209,7 @@ def main():
adam_epsilon=config.training.adam_epsilon,
weight_decay=config.training.weight_decay,
max_grad_norm=config.training.max_grad_norm,
bf16=True, # We're sticking with this known good reduced precision option
bf16=True, # We're sticking with this known good reduced precision option
eval_strategy=config.training.evaluation_strategy,
eval_steps=config.training.eval_steps,
save_strategy=config.training.save_strategy,

View File

@ -167,7 +167,7 @@ reportPrivateImportUsage = false
line-length = 160
target-version = "py311"
exclude = ["olmocr/train/molmo", "tests/*"]
ignore = ["E722"] #igore bare except
ignore = ["E722", "F541"] #igore bare except, and f string without placeholders
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]