Mirror of https://github.com/allenai/olmocr.git, synced 2025-12-13 00:09:42 +00:00

Lint fixes

parent 5ec49672ea
commit 6e8272413c
@@ -17,7 +17,7 @@ import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from functools import cache
from io import BytesIO
from urllib.parse import urlparse

@@ -34,7 +34,6 @@ from olmocr.check import (
check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.train.dataloader import FrontMatterParser
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
@@ -48,6 +47,7 @@ from olmocr.s3_utils import (
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

@@ -227,7 +227,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:

# Enable guided decoding regex if needed
if args.guided_decoding:
query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
query["guided_regex"] = (
r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)

logger.info(f"Built page query for {pdf_orig_path}-{page_num}")

@@ -247,7 +249,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
raise ValueError("Response exceeded model_max_context, cannot use this response")

if base_response_data["choices"][0]["finish_reason"] != "stop":
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
@@ -329,6 +331,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
is_fallback=True,
)

async def process_pdf(args, worker_id: int, pdf_orig_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
try:
@@ -586,9 +589,9 @@ async def vllm_server_task(model_name_or_path, args, semaphore):

if args.gpu_memory_utilization is not None:
cmd.extend(["--gpu-memory-utilization", str(args.gpu_memory_utilization)])

if args.max_model_len is not None:
cmd.extend(["--max-model-len", str(args.max_model_len)])
cmd.extend(["--max-model-len", str(args.max_model_len)])

proc = await asyncio.create_subprocess_exec(
*cmd,
@@ -1016,7 +1019,11 @@ async def main():
)

parser.add_argument("--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve).")
parser.add_argument("--max_model_len", type=int, help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",)
parser.add_argument(
"--max_model_len",
type=int,
help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",
)

parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
@@ -1041,7 +1048,7 @@ async def main():
logger.info(
"If you run out of GPU memory during start-up or get 'KV cache is larger than available memory' errors, retry with lower values, e.g. --gpu_memory_utilization 0.80 --max_model_len 16384"
)

global workspace_s3, pdf_s3
# set the global BASE_SERVER_PORT from args
global BASE_SERVER_PORT
@@ -1227,12 +1234,12 @@ async def main():

# Output finished_on_attempt statistics
logger.info("\nPages finished by attempt number:")
total_finished = sum(total_metrics.get(f'finished_on_attempt_{i}', 0) for i in range(args.max_page_retries))
total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
cumulative = 0

for i in range(args.max_page_retries):
if f'finished_on_attempt_{i}' in total_metrics:
count = total_metrics[f'finished_on_attempt_{i}']
if f"finished_on_attempt_{i}" in total_metrics:
count = total_metrics[f"finished_on_attempt_{i}"]
cumulative += count
percentage = (count / total_finished * 100) if total_finished > 0 else 0
cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
@@ -1253,4 +1260,4 @@ async def main():

if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
@@ -1,8 +1,8 @@
from .prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
build_no_anchoring_yaml_prompt,
build_openai_silver_data_prompt,
extract_raw_text,
openai_response_format_schema,
)

@@ -109,9 +109,9 @@ def build_finetuning_prompt(base_text: str) -> str:

def build_no_anchoring_yaml_prompt() -> str:
return (
f"Attached is one page of a document that you must process. "
f"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
f"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
"Attached is one page of a document that you must process. "
"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
)
@@ -6,22 +6,21 @@ Processes prompts and images from WildVision-bench until finding significant mis

import argparse
import asyncio
import gc
import os
import glob
import tempfile
import shutil
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from huggingface_hub import snapshot_download
import random
import numpy as np
from typing import List, Dict
import base64
from io import BytesIO
import PIL.Image
import logging
import os
import random
import shutil
import tempfile
from io import BytesIO
from typing import Dict, List

import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForVision2Seq, AutoProcessor
from vllm import LLM, SamplingParams

from olmocr.pipeline import build_page_query
from olmocr.s3_utils import download_directory
@@ -51,7 +50,7 @@ async def download_model(model_name_or_path: str, max_retries: int = 5):
if retry == max_retries - 1:
raise # Raise on final attempt and fail the job
logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
await asyncio.sleep(2 ** retry) # Exponential backoff
await asyncio.sleep(2**retry) # Exponential backoff

def image_to_base64_data_url(image):
@ -65,38 +64,35 @@ def image_to_base64_data_url(image):
|
||||
async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: int = 2048) -> List[Dict[str, str]]:
|
||||
"""Load prompts and images from olmOCR-mix-0225-benchmarkset dataset with fixed random seed."""
|
||||
print(f"Loading olmOCR-mix-0225-benchmarkset dataset with {num_samples} samples and seed {seed}")
|
||||
|
||||
|
||||
# Set random seed for reproducibility
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
# Import huggingface_hub utilities to list files
|
||||
from huggingface_hub import list_repo_files, hf_hub_download
|
||||
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
|
||||
# List all PDF files in the repository
|
||||
print("Listing PDF files in dataset...")
|
||||
all_files = list_repo_files(
|
||||
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
|
||||
repo_type="dataset"
|
||||
)
|
||||
|
||||
all_files = list_repo_files(repo_id="allenai/olmOCR-mix-0225-benchmarkset", repo_type="dataset")
|
||||
|
||||
# Filter for PDF files in the pdfs directory
|
||||
pdf_files = [f for f in all_files if f.startswith("pdfs/") and f.endswith(".pdf")]
|
||||
|
||||
|
||||
if not pdf_files:
|
||||
raise ValueError("No PDF files found in the dataset")
|
||||
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files in dataset")
|
||||
|
||||
|
||||
# Randomly sample num_samples PDFs
|
||||
if len(pdf_files) > num_samples:
|
||||
sampled_pdf_files = random.sample(pdf_files, num_samples)
|
||||
else:
|
||||
sampled_pdf_files = pdf_files
|
||||
print(f"Warning: Only {len(pdf_files)} PDFs available, less than requested {num_samples}")
|
||||
|
||||
|
||||
print(f"Sampled {len(sampled_pdf_files)} PDFs to download")
|
||||
|
||||
|
||||
# Download only the sampled PDFs and process them
|
||||
queries = []
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
@ -104,99 +100,82 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
|
||||
try:
|
||||
# Download individual PDF file
|
||||
print(f"Downloading {pdf_file}...")
|
||||
local_pdf_path = hf_hub_download(
|
||||
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
|
||||
filename=pdf_file,
|
||||
repo_type="dataset",
|
||||
local_dir=temp_dir
|
||||
)
|
||||
|
||||
local_pdf_path = hf_hub_download(repo_id="allenai/olmOCR-mix-0225-benchmarkset", filename=pdf_file, repo_type="dataset", local_dir=temp_dir)
|
||||
|
||||
# Build page query for page 1 of each PDF
|
||||
query = await build_page_query(
|
||||
local_pdf_path=local_pdf_path,
|
||||
page=1,
|
||||
target_longest_image_dim=1280,
|
||||
image_rotation=0
|
||||
)
|
||||
query = await build_page_query(local_pdf_path=local_pdf_path, page=1, target_longest_image_dim=1280, image_rotation=0)
|
||||
queries.append(query)
|
||||
except Exception as e:
|
||||
print(f"Error processing {os.path.basename(pdf_file)}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Successfully processed {len(queries)} PDFs")
|
||||
return queries
|
||||
|
||||
|
||||
def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
|
||||
"""Process a single prompt with image and return comparison results."""
|
||||
# Track if we found the first mismatch for max_prob_first_diff
|
||||
found_first_mismatch = False
|
||||
max_prob_first_diff = 0.0
|
||||
# Extract messages from the sample (which is the output of build_page_query)
|
||||
messages = sample['messages']
|
||||
|
||||
messages = sample["messages"]
|
||||
|
||||
# Extract the text prompt and image from the messages
|
||||
user_message = messages[0]
|
||||
text_prompt = None
|
||||
image_base64 = None
|
||||
|
||||
for content in user_message['content']:
|
||||
if content['type'] == 'text':
|
||||
text_prompt = content['text']
|
||||
elif content['type'] == 'image_url':
|
||||
image_url = content['image_url']['url']
|
||||
|
||||
for content in user_message["content"]:
|
||||
if content["type"] == "text":
|
||||
text_prompt = content["text"]
|
||||
elif content["type"] == "image_url":
|
||||
image_url = content["image_url"]["url"]
|
||||
# Extract base64 data after the comma
|
||||
if ',' in image_url:
|
||||
image_base64 = image_url.split(',')[1]
|
||||
if "," in image_url:
|
||||
image_base64 = image_url.split(",")[1]
|
||||
else:
|
||||
image_base64 = image_url
|
||||
|
||||
|
||||
if text_prompt is None or image_base64 is None:
|
||||
raise ValueError("Failed to extract text prompt or image from messages")
|
||||
|
||||
|
||||
# Decode the base64 image to PIL Image
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image = PIL.Image.open(BytesIO(image_bytes))
|
||||
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"PROMPT: {text_prompt[:100]}..." if len(text_prompt) > 100 else f"PROMPT: {text_prompt}")
|
||||
print(f"IMAGE: {image.size} {image.mode}")
|
||||
|
||||
|
||||
# Generate with vLLM
|
||||
print("\n=== vLLM Generation ===")
|
||||
|
||||
|
||||
# For VLLM, use the messages just as comes out of build_page_query
|
||||
outputs = llm.chat(messages, sampling_params)
|
||||
output = outputs[0]
|
||||
|
||||
|
||||
# Extract prompt and generated token IDs
|
||||
prompt_token_ids = output.prompt_token_ids
|
||||
generated_token_ids = output.outputs[0].token_ids
|
||||
|
||||
print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
|
||||
print(
|
||||
f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..."
|
||||
if len(prompt_token_ids) > 10
|
||||
else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}"
|
||||
)
|
||||
print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
|
||||
print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")
|
||||
|
||||
|
||||
# Create input tensor from concatenated token IDs
|
||||
# input_ids = torch.tensor([all_token_ids], device=device) # Not needed for HF VLM models
|
||||
|
||||
|
||||
# HuggingFace forward pass
|
||||
print("\n=== HuggingFace Forward Pass ===")
|
||||
# Prepare inputs for HF model using the extracted image and text
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": text_prompt}
|
||||
]
|
||||
}
|
||||
]
|
||||
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text_prompt}]}]
|
||||
hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
||||
inputs = processor(
|
||||
text=[hf_text_prompt],
|
||||
images=[image],
|
||||
return_tensors="pt"
|
||||
).to(device)
|
||||
inputs = processor(text=[hf_text_prompt], images=[image], return_tensors="pt").to(device)
|
||||
|
||||
print("INPUTS", inputs)
|
||||
|
||||
@ -204,32 +183,32 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
|
||||
generated_ids_tensor = torch.tensor([generated_token_ids], device=device)
|
||||
inputs["input_ids"] = torch.cat([inputs["input_ids"], generated_ids_tensor], dim=1)
|
||||
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_hf = hf_model(**inputs)
|
||||
logits = outputs_hf.logits[0] # [seq_len, vocab_size]
|
||||
|
||||
|
||||
# Token-by-token comparison
|
||||
print(f"\n{'Pos':>4} {'Token ID':>8} {'Token':>20} {'Type':>8} {'vLLM Prob':>12} {'HF Argmax':>10} {'HF Prob':>12} {'Match':>6} {'HF Token':>20}")
|
||||
print("-" * 125)
|
||||
|
||||
|
||||
# Get vLLM logprobs for generated tokens
|
||||
vllm_logprobs = output.outputs[0].logprobs
|
||||
|
||||
|
||||
# Track mismatch info
|
||||
first_mismatch_idx = None
|
||||
|
||||
|
||||
# Get all token IDs from the HF model's input
|
||||
all_token_ids = inputs["input_ids"][0].tolist()
|
||||
|
||||
|
||||
# Compare ALL tokens (prompt + generated)
|
||||
for pos, token_id in enumerate(all_token_ids):
|
||||
token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
|
||||
|
||||
token_str = processor.decode([token_id], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
|
||||
|
||||
# Determine if this is a prompt or generated token
|
||||
is_prompt = pos < len(prompt_token_ids)
|
||||
token_type = "prompt" if is_prompt else "gen"
|
||||
|
||||
|
||||
# vLLM probability (only for generated tokens)
|
||||
vllm_prob_str = "N/A"
|
||||
vllm_prob = None
|
||||
@ -242,17 +221,17 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
|
||||
# Convert logprob to probability
|
||||
vllm_prob = torch.exp(torch.tensor(token_logprobs[token_id].logprob)).item()
|
||||
vllm_prob_str = f"{vllm_prob:12.6f}"
|
||||
|
||||
|
||||
# HF prediction - only for generated tokens (skip prompt tokens entirely)
|
||||
if pos > 0 and not is_prompt:
|
||||
hf_logits_at_pos = logits[pos - 1]
|
||||
hf_probs = torch.softmax(hf_logits_at_pos, dim=-1)
|
||||
hf_argmax = torch.argmax(hf_logits_at_pos).item()
|
||||
hf_prob = hf_probs[token_id].item()
|
||||
|
||||
|
||||
# Check if predictions match
|
||||
match = "✓" if token_id == hf_argmax else "✗"
|
||||
|
||||
|
||||
# Track first mismatch and probability difference
|
||||
if token_id != hf_argmax:
|
||||
if first_mismatch_idx is None:
|
||||
@ -261,21 +240,21 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
|
||||
if vllm_prob is not None and not found_first_mismatch:
|
||||
max_prob_first_diff = abs(vllm_prob - hf_prob)
|
||||
found_first_mismatch = True
|
||||
|
||||
|
||||
# Decode HF argmax token (only show if mismatch)
|
||||
hf_token_str = ""
|
||||
if token_id != hf_argmax:
|
||||
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
|
||||
|
||||
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")
|
||||
|
||||
print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
|
||||
else:
|
||||
# Prompt tokens or first token - no HF comparison
|
||||
print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {'':>10} {'':>12} {'':>6} {'':<20}")
|
||||
|
||||
|
||||
# Summary
|
||||
print(f"\n=== Summary ===")
|
||||
print(f"Total tokens generated: {len(generated_token_ids)}")
|
||||
|
||||
|
||||
# Calculate match rate
|
||||
matches = 0
|
||||
for i, token_id in enumerate(generated_token_ids):
|
||||
@ -284,39 +263,33 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
|
||||
hf_argmax = torch.argmax(hf_logits_at_pos).item()
|
||||
if token_id == hf_argmax:
|
||||
matches += 1
|
||||
|
||||
|
||||
match_rate = matches / len(generated_token_ids) * 100 if generated_token_ids else 0
|
||||
print(f"Token match rate: {matches}/{len(generated_token_ids)} ({match_rate:.1f}%)")
|
||||
|
||||
|
||||
# Report first mismatch index
|
||||
if first_mismatch_idx is not None:
|
||||
print(f"First mismatch at generation index: {first_mismatch_idx}")
|
||||
print(f"First mismatch probability difference: {max_prob_first_diff:.6f}")
|
||||
else:
|
||||
print("No mismatches found in generated tokens")
|
||||
|
||||
|
||||
return {
|
||||
'first_mismatch_idx': first_mismatch_idx,
|
||||
'max_prob_first_diff': max_prob_first_diff,
|
||||
'match_rate': match_rate,
|
||||
'num_generated': len(generated_token_ids)
|
||||
"first_mismatch_idx": first_mismatch_idx,
|
||||
"max_prob_first_diff": max_prob_first_diff,
|
||||
"match_rate": match_rate,
|
||||
"num_generated": len(generated_token_ids),
|
||||
}
|
||||
|
||||
|
||||
async def async_main():
|
||||
parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
|
||||
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
help="Model name or path")
|
||||
parser.add_argument("--max-tokens", type=int, default=20,
|
||||
help="Maximum tokens to generate per prompt")
|
||||
parser.add_argument("--temperature", type=float, default=0.0,
|
||||
help="Sampling temperature")
|
||||
parser.add_argument("--num-prompts", type=int, default=100,
|
||||
help="Number of prompts to load from WildVision")
|
||||
parser.add_argument("--prob-threshold", type=float, default=0.20,
|
||||
help="Probability difference threshold to stop")
|
||||
parser.add_argument("--seed", type=int, default=42,
|
||||
help="Random seed for prompt selection")
|
||||
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Model name or path")
|
||||
parser.add_argument("--max-tokens", type=int, default=20, help="Maximum tokens to generate per prompt")
|
||||
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
|
||||
parser.add_argument("--num-prompts", type=int, default=100, help="Number of prompts to load from WildVision")
|
||||
parser.add_argument("--prob-threshold", type=float, default=0.20, help="Probability difference threshold to stop")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed for prompt selection")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Model: {args.model}")
|
||||
@ -330,72 +303,63 @@ async def async_main():
|
||||
|
||||
# Load prompts and images
|
||||
samples = await load_pdf_prompts(num_samples=args.num_prompts, seed=args.seed)
|
||||
|
||||
|
||||
# Load HuggingFace model and processor first
|
||||
print("\n=== Loading HuggingFace Model ===")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||
hf_model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
hf_model = AutoModelForVision2Seq.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
|
||||
hf_model.eval()
|
||||
|
||||
|
||||
# Create vLLM engine once
|
||||
print("\n=== Creating vLLM Engine ===")
|
||||
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=args.temperature,
|
||||
max_tokens=args.max_tokens,
|
||||
logprobs=1 # Get top-1 logprobs
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens, logprobs=1) # Get top-1 logprobs
|
||||
|
||||
# Process samples until finding significant mismatch
|
||||
print("\n=== Processing Samples ===")
|
||||
|
||||
|
||||
# Initialize statistics tracking
|
||||
all_results = []
|
||||
for i, sample in enumerate(samples):
|
||||
print(f"\n\n{'#'*80}")
|
||||
print(f"### Processing sample {i+1}/{len(samples)}")
|
||||
print(f"{'#'*80}")
|
||||
|
||||
|
||||
# Process single sample
|
||||
result = process_single_prompt(sample, llm, hf_model, processor_hf, sampling_params, device, args)
|
||||
all_results.append(result)
|
||||
|
||||
|
||||
# Check if we found significant mismatch
|
||||
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
|
||||
if result["first_mismatch_idx"] is not None and result["max_prob_first_diff"] > args.prob_threshold:
|
||||
print(f"\n\n{'*'*80}")
|
||||
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
|
||||
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
|
||||
print(f"*** Stopping after sample {i+1}/{len(samples)} ***")
|
||||
print(f"{'*'*80}")
|
||||
|
||||
|
||||
# Report aggregated statistics
|
||||
print(f"\n\n{'='*80}")
|
||||
print("=== AGGREGATED STATISTICS ===")
|
||||
print(f"{'='*80}")
|
||||
|
||||
|
||||
total_samples = len(all_results)
|
||||
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
|
||||
total_tokens_generated = sum(r['num_generated'] for r in all_results)
|
||||
|
||||
samples_with_mismatches = sum(1 for r in all_results if r["first_mismatch_idx"] is not None)
|
||||
total_tokens_generated = sum(r["num_generated"] for r in all_results)
|
||||
|
||||
print(f"Total samples processed: {total_samples}")
|
||||
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
|
||||
print(f"Total tokens generated: {total_tokens_generated}")
|
||||
|
||||
|
||||
if samples_with_mismatches > 0:
|
||||
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
|
||||
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
|
||||
avg_match_rate = sum(r["match_rate"] for r in all_results) / total_samples
|
||||
max_prob_diffs = [r["max_prob_first_diff"] for r in all_results if r["first_mismatch_idx"] is not None]
|
||||
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
|
||||
max_prob_diff_overall = max(max_prob_diffs)
|
||||
|
||||
first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
|
||||
|
||||
first_mismatch_positions = [r["first_mismatch_idx"] for r in all_results if r["first_mismatch_idx"] is not None]
|
||||
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)
|
||||
|
||||
|
||||
print(f"\nMismatch Statistics:")
|
||||
print(f" Average token match rate: {avg_match_rate:.1f}%")
|
||||
print(f" Average first mismatch position: {avg_first_mismatch_pos:.1f}")
|
||||
@ -403,7 +367,7 @@ async def async_main():
|
||||
print(f" Max first mismatch prob diff: {max_prob_diff_overall:.6f}")
|
||||
else:
|
||||
print("\nNo mismatches found in any samples!")
|
||||
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
|
||||
|
||||
@ -412,4 +376,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -7,7 +7,7 @@ Compresses OlmOCR checkpoints using FP8 quantization:

Usage:
python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N] [--calibration-pdfs PDF1+PDF2+...]

source_path: Path to checkpoint (local or S3)
destination_path: Where to save compressed checkpoint (local or S3)
recipe_path: Path to quantization config YAML file
@@ -26,50 +26,54 @@ import shutil
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple, Union, List
from typing import List, Optional, Tuple, Union

import boto3
import torch
from datasets import Dataset
from llmcompressor import oneshot
from PIL import Image
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)

from olmocr.s3_utils import parse_s3_path
from olmocr.pipeline import build_page_query

from olmocr.s3_utils import parse_s3_path

s3_client = boto3.client("s3")

def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
"""Get calibration PDFs from provided paths.
|
||||
|
||||
|
||||
Args:
|
||||
num_samples: Number of samples to use
|
||||
pdf_paths: List of local PDF paths
|
||||
|
||||
|
||||
Returns:
|
||||
List of valid PDF paths
|
||||
"""
|
||||
print(f"Using {len(pdf_paths)} provided calibration PDFs")
|
||||
|
||||
|
||||
# If more PDFs provided than needed, randomly sample
|
||||
if len(pdf_paths) > num_samples:
|
||||
pdf_paths = random.sample(pdf_paths, num_samples)
|
||||
print(f"Randomly sampled {num_samples} PDFs from provided paths")
|
||||
|
||||
|
||||
# Verify all PDFs exist
|
||||
valid_paths = []
|
||||
for path in pdf_paths:
|
||||
if os.path.exists(path) and path.endswith('.pdf'):
|
||||
if os.path.exists(path) and path.endswith(".pdf"):
|
||||
valid_paths.append(path)
|
||||
else:
|
||||
print(f" Warning: Skipping invalid path: {path}")
|
||||
|
||||
|
||||
if not valid_paths:
|
||||
raise ValueError("No valid PDF paths found in the provided list")
|
||||
|
||||
|
||||
print(f"Using {len(valid_paths)} valid calibration PDFs")
|
||||
return valid_paths
|
||||
|
||||
@ -77,14 +81,14 @@ def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
|
||||
async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Dataset:
|
||||
"""Prepare calibration dataset from PDFs using build_page_query."""
|
||||
dataset_items = []
|
||||
|
||||
|
||||
for pdf_path in pdf_paths:
|
||||
# Get first page of each PDF (page 0)
|
||||
query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
|
||||
|
||||
|
||||
# Extract the messages
|
||||
messages = query["messages"]
|
||||
|
||||
|
||||
# Extract images from the message content
|
||||
images = []
|
||||
for message in messages:
|
||||
@ -99,12 +103,10 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
|
||||
image_bytes = base64.b64decode(base64_str)
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
images.append(image)
|
||||
|
||||
|
||||
# Apply chat template to get the text
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
# Process with tokenizer
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
@ -113,31 +115,32 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
|
||||
max_length=8192,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
|
||||
dataset_items.append(inputs)
|
||||
|
||||
|
||||
# Convert list of dicts to HuggingFace Dataset
|
||||
if dataset_items:
|
||||
# Create dataset in batches to avoid overflow
|
||||
batch_size = 50 # Process in smaller batches
|
||||
all_datasets = []
|
||||
|
||||
|
||||
for i in range(0, len(dataset_items), batch_size):
|
||||
batch = dataset_items[i:i + batch_size]
|
||||
batch = dataset_items[i : i + batch_size]
|
||||
# Flatten the batch into a dict of lists
|
||||
batch_dict = {}
|
||||
for key in batch[0].keys():
|
||||
batch_dict[key] = [item[key] for item in batch]
|
||||
|
||||
|
||||
# Create dataset for this batch
|
||||
batch_dataset = Dataset.from_dict(batch_dict)
|
||||
all_datasets.append(batch_dataset)
|
||||
|
||||
|
||||
# Concatenate all batch datasets
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
else:
|
||||
from datasets import concatenate_datasets
|
||||
|
||||
return concatenate_datasets(all_datasets)
|
||||
else:
|
||||
return Dataset.from_dict({})
|
||||
@ -151,21 +154,21 @@ def is_s3_path(path: str) -> bool:
|
||||
def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
|
||||
"""Download all files from S3 prefix to local directory."""
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
|
||||
|
||||
paginator = s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
|
||||
|
||||
|
||||
print(f"Downloading checkpoint from s3://{bucket}/{prefix} to {local_dir}...")
|
||||
|
||||
|
||||
for page in pages:
|
||||
for obj in page.get("Contents", []):
|
||||
key = obj["Key"]
|
||||
if key.endswith("/"):
|
||||
continue
|
||||
|
||||
|
||||
rel_path = os.path.relpath(key, prefix)
|
||||
local_path = os.path.join(local_dir, rel_path)
|
||||
|
||||
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
s3_client.download_file(bucket, key, local_path)
|
||||
print(f" Downloaded {rel_path}")
|
||||
@ -174,18 +177,20 @@ def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
|
||||
def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
|
||||
"""Upload all files from local directory to S3."""
|
||||
print(f"Uploading compressed checkpoint from {local_dir} to s3://{bucket}/{prefix}...")
|
||||
|
||||
|
||||
for root, _, files in os.walk(local_dir):
|
||||
for file in files:
|
||||
local_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(local_path, local_dir)
|
||||
s3_key = os.path.join(prefix, rel_path)
|
||||
|
||||
|
||||
s3_client.upload_file(local_path, bucket, s3_key)
|
||||
print(f" Uploaded {rel_path}")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
|
||||
def load_model_and_tokenizer(
|
||||
source_path: str,
|
||||
) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
|
||||
"""Load model and tokenizer from source path (local or S3)."""
|
||||
if is_s3_path(source_path):
|
||||
# Download from S3 to temporary directory
|
||||
@ -196,53 +201,37 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
|
||||
else:
|
||||
model_path = source_path
|
||||
temp_dir = None
|
||||
|
||||
|
||||
# Read config to determine model architecture
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
# Get model name from config
|
||||
model_name = config.get("name_or_path", "")
|
||||
|
||||
|
||||
print(f"Loading model from {model_path}...")
|
||||
|
||||
|
||||
# Load appropriate model class based on name
|
||||
if "Qwen2.5-VL" in model_name:
|
||||
print("Detected Qwen2.5-VL model")
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
device_map="auto",
|
||||
torch_dtype="auto"
|
||||
)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
|
||||
elif "Qwen2-VL" in model_name:
|
||||
print("Detected Qwen2-VL model")
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
device_map="auto",
|
||||
torch_dtype="auto"
|
||||
)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
|
||||
else:
|
||||
# Default to checking architectures list
|
||||
architectures = config.get("architectures", [])
|
||||
if "Qwen2_5_VLForConditionalGeneration" in architectures:
|
||||
print("Detected Qwen2.5-VL model from architectures")
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
device_map="auto",
|
||||
torch_dtype="auto"
|
||||
)
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
|
||||
else:
|
||||
print("Detected Qwen2-VL model from architectures")
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
device_map="auto",
|
||||
torch_dtype="auto"
|
||||
)
|
||||
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
|
||||
|
||||
print(f"Loading tokenizer from {model_path}...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
|
||||
return model, tokenizer, temp_dir
|
||||
|
||||
|
||||
@ -250,10 +239,10 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
|
||||
"""Copy additional config files that are needed but not saved by save_pretrained."""
|
||||
# List of additional files to copy if they exist
|
||||
additional_files = ["preprocessor_config.json", "chat_template.json"]
|
||||
|
||||
|
||||
# Determine the actual source path (could be temp dir if downloaded from S3)
|
||||
actual_source = temp_source_dir if temp_source_dir else source_path
|
||||
|
||||
|
||||
for filename in additional_files:
|
||||
source_file = os.path.join(actual_source, filename)
|
||||
if os.path.exists(source_file):
|
||||
@ -268,55 +257,50 @@ def data_collator(batch):
|
||||
return {key: torch.tensor(value) for key, value in batch[0].items()}
|
||||
|
||||
|
||||
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None) -> None:
|
||||
def compress_checkpoint(
|
||||
source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None
|
||||
) -> None:
|
||||
"""Compress OlmOCR checkpoint using FP8 quantization."""
|
||||
# Load model and tokenizer
|
||||
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
|
||||
|
||||
|
||||
try:
|
||||
# Print all model tensor names
|
||||
print("\n=== Model Tensor Names ===")
|
||||
for name, param in model.named_parameters():
|
||||
print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
|
||||
print("=========================\n")
|
||||
|
||||
|
||||
# Prepare calibration dataset if requested
|
||||
dataset = None
|
||||
|
||||
|
||||
if num_calibration_samples > 0:
|
||||
if not calibration_pdfs:
|
||||
raise ValueError("Calibration PDFs must be provided when num_calibration_samples > 0. Use --calibration-pdfs argument.")
|
||||
|
||||
|
||||
print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
|
||||
|
||||
|
||||
# Load processor for the model
|
||||
processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
|
||||
|
||||
|
||||
# Get calibration PDFs from provided paths
|
||||
pdf_paths = get_calibration_pdfs(num_calibration_samples, calibration_pdfs)
|
||||
|
||||
|
||||
# Prepare dataset
|
||||
dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
|
||||
|
||||
|
||||
print(f"✓ Prepared {len(dataset)} calibration samples")
|
||||
|
||||
|
||||
# Apply quantization using provided recipe
|
||||
print(f"\nApplying quantization using recipe: {recipe_path}")
|
||||
|
||||
|
||||
if dataset:
|
||||
oneshot(
|
||||
model=model,
|
||||
recipe=recipe_path,
|
||||
dataset=dataset,
|
||||
max_seq_length=8192,
|
||||
num_calibration_samples=len(dataset),
|
||||
data_collator=data_collator
|
||||
)
|
||||
oneshot(model=model, recipe=recipe_path, dataset=dataset, max_seq_length=8192, num_calibration_samples=len(dataset), data_collator=data_collator)
|
||||
else:
|
||||
oneshot(model=model, recipe=recipe_path)
|
||||
|
||||
|
||||
print("✓ Quantization completed successfully")
|
||||
|
||||
|
||||
# Save the compressed model
|
||||
if is_s3_path(dest_path):
|
||||
# Save to temporary directory first, then upload to S3
|
||||
@ -324,10 +308,10 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
|
||||
print(f"\nSaving compressed model to temporary directory...")
|
||||
model.save_pretrained(temp_dest_dir)
|
||||
tokenizer.save_pretrained(temp_dest_dir)
|
||||
|
||||
|
||||
# Copy additional files
|
||||
copy_additional_files(source_path, temp_dest_dir, temp_source_dir)
|
||||
|
||||
|
||||
# Upload to S3
|
||||
bucket, prefix = parse_s3_path(dest_path)
|
||||
upload_local_to_s3(temp_dest_dir, bucket, prefix)
|
||||
@ -337,18 +321,18 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
model.save_pretrained(dest_path)
|
||||
tokenizer.save_pretrained(dest_path)
|
||||
|
||||
|
||||
# Copy additional files
|
||||
copy_additional_files(source_path, dest_path, temp_source_dir)
|
||||
|
||||
|
||||
print(f"\n✓ Successfully compressed checkpoint and saved to {dest_path}")
|
||||
|
||||
|
||||
finally:
|
||||
# Clean up temporary source directory if needed
|
||||
if temp_source_dir:
|
||||
print(f"Cleaning up temporary directory {temp_source_dir}...")
|
||||
shutil.rmtree(temp_source_dir)
|
||||
|
||||
|
||||
# Free up GPU memory
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
@ -377,48 +361,49 @@ Examples:
|
||||
|
||||
# Using recursive glob pattern
|
||||
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
|
||||
"""
|
||||
""",
|
||||
)
|
||||
parser.add_argument("source", help="Source checkpoint path (local or S3)")
|
||||
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
|
||||
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
|
||||
parser.add_argument("--num-calibration-samples", type=int, default=512,
|
||||
help="Number of calibration samples to use (default: 512, set to 0 to disable)")
|
||||
parser.add_argument("--calibration-pdfs", type=str, default=None,
|
||||
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
|
||||
|
||||
parser.add_argument("--num-calibration-samples", type=int, default=512, help="Number of calibration samples to use (default: 512, set to 0 to disable)")
|
||||
parser.add_argument(
|
||||
"--calibration-pdfs",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Parse calibration PDFs if provided
|
||||
calibration_pdfs = None
|
||||
if args.calibration_pdfs:
|
||||
# Use pathlib for better glob handling
|
||||
pattern = args.calibration_pdfs
|
||||
|
||||
|
||||
# Handle both absolute and relative paths with recursive glob
|
||||
if '**' in pattern:
|
||||
if "**" in pattern:
|
||||
# For recursive patterns, we need to handle them specially
|
||||
if pattern.startswith('/'):
|
||||
if pattern.startswith("/"):
|
||||
# Absolute path with **
|
||||
parts = pattern.split('**')
|
||||
parts = pattern.split("**")
|
||||
base_path = Path(parts[0])
|
||||
glob_pattern = '**' + parts[1] if len(parts) > 1 else '**/*.pdf'
|
||||
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip('/')))
|
||||
glob_pattern = "**" + parts[1] if len(parts) > 1 else "**/*.pdf"
|
||||
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip("/")))
|
||||
else:
|
||||
# Relative path with **
|
||||
calibration_pdfs = list(Path('.').glob(pattern))
|
||||
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == '.pdf']
|
||||
calibration_pdfs = list(Path(".").glob(pattern))
|
||||
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == ".pdf"]
|
||||
else:
|
||||
# Use standard glob for non-recursive patterns
|
||||
calibration_pdfs = glob.glob(pattern)
|
||||
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith('.pdf')]
|
||||
|
||||
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith(".pdf")]
|
||||
|
||||
print(f"Found {len(calibration_pdfs)} PDF files matching pattern: {args.calibration_pdfs}")
|
||||
|
||||
|
||||
compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from omegaconf import DictConfig, OmegaConf
from omegaconf import OmegaConf

@dataclass
@@ -200,7 +200,7 @@ class TrainingConfig:
# Performance
dataloader_drop_last: bool = True
dataloader_num_workers: int = 16

# Data collator settings
collator_max_token_len: Optional[int] = None
remove_unused_columns: bool = False # Important for custom datasets
@@ -314,10 +314,10 @@ class Config:
FrontMatterOutputFormat,
FrontMatterParser,
InstructUserMessages,
JSONOutputFormat,
LatexBracketNormalizer,
NewYamlFinetuningPromptWithAnchoring,
NewYamlFinetuningPromptWithNoAnchoring,
JSONOutputFormat,
PDFRenderer,
RandomTokenFlipper,
StaticLengthDocumentAnchoring,
@@ -382,18 +382,18 @@ class Config:
if processor is None:
raise ValueError("Processor must be provided for RandomTokenFlipper step (to get valid tokens)")
tokenizer = processor.tokenizer

# Get all special token IDs to exclude
special_token_ids = set()
for token in tokenizer.all_special_tokens:
special_token_ids.add(tokenizer.convert_tokens_to_ids(token))

# Get all token IDs that are not special tokens
valid_token_ids = []
for token_id in range(len(tokenizer)):
if token_id not in special_token_ids:
valid_token_ids.append(token_id)

steps.append(
RandomTokenFlipper(
valid_token_ids=valid_token_ids,

@@ -1,4 +1,5 @@
import base64
import json
import logging
import re
from abc import ABC, abstractmethod
@@ -8,8 +9,7 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
import json
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias

import numpy as np
import yaml
@ -31,10 +31,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optional[Tuple[Path, str]]]:
|
||||
"""Validate a single markdown-PDF pair.
|
||||
|
||||
|
||||
Args:
|
||||
md_path: Path to the markdown file
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (valid_sample, invalid_pdf_info)
|
||||
- valid_sample: Dict with markdown_path and pdf_path if valid, None otherwise
|
||||
@ -42,31 +42,32 @@ def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optiona
|
||||
"""
|
||||
# Look for PDF with same stem (filename without extension)
|
||||
pdf_path = md_path.with_suffix(".pdf")
|
||||
|
||||
|
||||
if pdf_path.exists() or pdf_path.is_symlink():
|
||||
# Resolve symlink if it is one
|
||||
if pdf_path.is_symlink():
|
||||
pdf_path = pdf_path.resolve()
|
||||
|
||||
|
||||
# Verify the resolved path exists
|
||||
if pdf_path.exists():
|
||||
# Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
|
||||
try:
|
||||
reader = PdfReader(str(pdf_path))
|
||||
num_pages = len(reader.pages)
|
||||
|
||||
|
||||
if num_pages != 1:
|
||||
return None, (pdf_path, f"Expected 1 page, found {num_pages}")
|
||||
|
||||
|
||||
# Test that document anchoring works
|
||||
from olmocr.prompts.anchor import get_anchor_text
|
||||
|
||||
get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
|
||||
|
||||
|
||||
return {"markdown_path": md_path, "pdf_path": pdf_path}, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return None, (pdf_path, f"Failed to load: {str(e)}")
|
||||
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
@ -104,29 +105,29 @@ class BaseMarkdownPDFDataset(Dataset):
|
||||
invalid_pdfs = []
|
||||
|
||||
logger.info(f"Validating {len(md_files)} markdown-PDF pairs using ProcessPoolExecutor...")
|
||||
|
||||
|
||||
# Use ProcessPoolExecutor for parallel validation
|
||||
with ProcessPoolExecutor(max_workers=8) as executor:
|
||||
# Submit all validation tasks
|
||||
future_to_md = {executor.submit(validate_pdf_pair, md_path): md_path for md_path in md_files}
|
||||
|
||||
|
||||
# Process results as they complete
|
||||
with tqdm(total=len(md_files), desc="Validating PDFs") as pbar:
|
||||
for future in as_completed(future_to_md):
|
||||
md_path = future_to_md[future]
|
||||
try:
|
||||
valid_sample, invalid_pdf_info = future.result()
|
||||
|
||||
|
||||
if valid_sample:
|
||||
self.samples.append(valid_sample)
|
||||
valid_count += 1
|
||||
elif invalid_pdf_info:
|
||||
invalid_pdfs.append(invalid_pdf_info)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {md_path}: {str(e)}")
|
||||
invalid_pdfs.append((md_path.with_suffix(".pdf"), f"Processing error: {str(e)}"))
|
||||
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
logger.info(f"Found {valid_count} valid markdown-PDF pairs")
|
||||
@@ -205,11 +206,11 @@ class FrontMatterParser(PipelineStep):
value = front_matter_dict[field_name]

# Handle type conversions
if field_type == int and isinstance(value, str):
if field_type is int and isinstance(value, str):
kwargs[field_name] = int(value)
elif field_type == bool and isinstance(value, str):
elif field_type is bool and isinstance(value, str):
kwargs[field_name] = value.lower() == "true"
elif field_type == Optional[str]:
elif field_type is Optional[str]:
kwargs[field_name] = value if value else None
else:
kwargs[field_name] = value
@ -288,7 +289,7 @@ class FinetuningPrompt(PipelineStep):
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
sample["instruction_prompt"] = build_finetuning_prompt(sample["anchor_text"])
|
||||
return sample
|
||||
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class NewYamlFinetuningPromptWithAnchoring(PipelineStep):
|
||||
@ -323,7 +324,7 @@ class FrontMatterOutputFormat(PipelineStep):
|
||||
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
page_data = sample["page_data"]
|
||||
assert type(page_data) == PageResponse
|
||||
assert type(page_data) is PageResponse
|
||||
|
||||
sample["response"] = (
|
||||
f"""---
|
||||
@ -346,58 +347,63 @@ class JSONOutputFormat(PipelineStep):
|
||||
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
page_data = sample["page_data"]
|
||||
assert type(page_data) == PageResponse
|
||||
assert type(page_data) is PageResponse
|
||||
|
||||
sample["response"] = json.dumps({
|
||||
"primary_language": page_data.primary_language,
|
||||
"is_rotation_valid": page_data.is_rotation_valid,
|
||||
"rotation_correction": page_data.rotation_correction,
|
||||
"is_table": page_data.is_table,
|
||||
"is_diagram": page_data.is_diagram,
|
||||
"natural_text": page_data.natural_text
|
||||
}, ensure_ascii=False)
|
||||
sample["response"] = json.dumps(
|
||||
{
|
||||
"primary_language": page_data.primary_language,
|
||||
"is_rotation_valid": page_data.is_rotation_valid,
|
||||
"rotation_correction": page_data.rotation_correction,
|
||||
"is_table": page_data.is_table,
|
||||
"is_diagram": page_data.is_diagram,
|
||||
"natural_text": page_data.natural_text,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LatexBracketNormalizer(PipelineStep):
|
||||
"""Normalizes LaTeX brackets in natural text field."""
|
||||
|
||||
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
"""Normalize LaTeX brackets in the natural text field."""
|
||||
# Get the page_data object
|
||||
if "page_data" not in sample:
|
||||
return sample
|
||||
|
||||
|
||||
page_data = sample["page_data"]
|
||||
if not hasattr(page_data, "natural_text") or not page_data.natural_text:
|
||||
return sample
|
||||
|
||||
|
||||
text = page_data.natural_text
|
||||
|
||||
|
||||
# Define patterns for LaTeX normalization
|
||||
# Order matters: process display math first, then inline
|
||||
patterns = [
|
||||
(r"\$\$(.+?)\$\$", r"\[\1\]"), # $$...$$ to \[...\]
|
||||
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
|
||||
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
|
||||
]
|
||||
|
||||
|
||||
# Apply replacements
|
||||
for pattern, replacement in patterns:
|
||||
text = re.sub(pattern, replacement, text, flags=re.DOTALL)
|
||||
|
||||
|
||||
# Update the page_data with normalized text
|
||||
# Since PageResponse is frozen, we need to create a new instance
|
||||
from olmocr.prompts.prompts import PageResponse
|
||||
|
||||
new_page_data = PageResponse(
|
||||
primary_language=page_data.primary_language,
|
||||
is_rotation_valid=page_data.is_rotation_valid,
|
||||
rotation_correction=page_data.rotation_correction,
|
||||
is_table=page_data.is_table,
|
||||
is_diagram=page_data.is_diagram,
|
||||
natural_text=text
|
||||
natural_text=text,
|
||||
)
|
||||
|
||||
|
||||
sample["page_data"] = new_page_data
|
||||
return sample
|
||||
|
||||
@ -493,26 +499,26 @@ class Tokenizer(PipelineStep):
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RandomTokenFlipper(PipelineStep):
|
||||
"""Randomly flips tokens in the output (non-masked) portion and masks their labels."""
|
||||
|
||||
|
||||
valid_token_ids: List[int] # List of valid token IDs to substitute with
|
||||
token_flip_rate: float = 1e-4
|
||||
masking_index: int = -100
|
||||
|
||||
|
||||
def __call__(self, sample: Sample) -> Sample:
|
||||
"""Randomly flip tokens in the non-masked portion of labels."""
|
||||
if "labels" not in sample or "input_ids" not in sample:
|
||||
return sample
|
||||
|
||||
|
||||
# Work with copies to avoid modifying original arrays
|
||||
labels = sample["labels"].copy()
|
||||
input_ids = sample["input_ids"].copy()
|
||||
|
||||
|
||||
# Find indices where labels are not masked (i.e., output tokens)
|
||||
non_masked_indices = np.where(labels != self.masking_index)[0]
|
||||
|
||||
|
||||
if len(non_masked_indices) == 0:
|
||||
return sample
|
||||
|
||||
|
||||
# For each non-masked token, independently decide whether to flip
|
||||
for idx in non_masked_indices:
|
||||
if np.random.random() < self.token_flip_rate:
|
||||
@ -520,11 +526,11 @@ class RandomTokenFlipper(PipelineStep):
|
||||
random_token = np.random.choice(self.valid_token_ids)
|
||||
input_ids[idx] = random_token
|
||||
labels[idx] = self.masking_index
|
||||
|
||||
|
||||
# Update sample with modified arrays
|
||||
sample["input_ids"] = input_ids
|
||||
sample["labels"] = labels
|
||||
|
||||
|
||||
return sample


@ -590,26 +596,27 @@ if __name__ == "__main__":
)

args = parser.parse_args()


# Import config module
from olmocr.train.config import Config

# Load configuration
print(f"\n=== Loading configuration from {args.config} ===")
config = Config.from_yaml(args.config)


# Validate configuration
try:
config.validate()
except ValueError as e:
print(f"Configuration validation failed: {e}")
exit(1)


# Load processor for tokenization
print(f"\nLoading processor: {config.model.name}")
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(config.model.name)


# Select dataset based on type
if args.dataset_type == "train":
dataset_configs = config.dataset.train
@ -617,19 +624,19 @@ if __name__ == "__main__":
else:
dataset_configs = config.dataset.eval
dataset_name = "eval"


if args.dataset_index >= len(dataset_configs):
print(f"Error: Dataset index {args.dataset_index} out of range. Only {len(dataset_configs)} {dataset_name} datasets available.")
exit(1)


dataset_cfg = dataset_configs[args.dataset_index]
root_dir = dataset_cfg["root_dir"]
pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)


print(f"\n=== Testing {dataset_name} dataset {args.dataset_index} ===")
print(f"Root directory: {root_dir}")
print(f"Pipeline steps: {[step.__class__.__name__ for step in pipeline_steps]}")


# Create dataset
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)

@ -641,7 +648,7 @@ if __name__ == "__main__":
for i in range(min(5, len(dataset))):
sample = dataset.samples[i]
print(f" {i}: MD: {sample['markdown_path'].name}, PDF: {sample['pdf_path'].name}")


# Check if sample index is valid
if args.sample_index >= len(dataset):
print(f"\nError: Sample index {args.sample_index} out of range. Only {len(dataset)} samples available.")
@ -650,39 +657,39 @@ if __name__ == "__main__":
# Get the requested sample
print(f"\n=== Displaying sample {args.sample_index} ===")
sample = dataset[args.sample_index]


# Display sample information based on pipeline output
print("\nSample keys:", list(sample.keys()))


# If it's raw data (no tokenization)
if 'markdown_path' in sample:
if "markdown_path" in sample:
print(f"\nMarkdown file: {sample['markdown_path'].name}")
if 'pdf_path' in sample:
if "pdf_path" in sample:
print(f"PDF file: {sample['pdf_path'].name}")
if 'image' in sample and hasattr(sample['image'], 'size'):
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")
if 'page_data' in sample:
if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if 'messages' in sample:
if "messages" in sample:
print(f"\n=== Messages ===")
for i, msg in enumerate(sample['messages']):
for i, msg in enumerate(sample["messages"]):
print(f"\nMessage {i}:")
print(f" Role: {msg['role']}")
print(f" Content preview: {str(msg['content'])[:200]}...")


# If it's tokenized data
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n=== Tokenized Output ===")
print(f" Keys: {list(sample.keys())}")
print(f" Input IDs shape: {sample['input_ids'].shape}")
print(f" Labels shape: {sample['labels'].shape}")
print(f" Attention mask shape: {sample['attention_mask'].shape}")


if "pixel_values" in sample:
print(f" Pixel values shape: {sample['pixel_values'].shape}")
if "image_grid_thw" in sample:
print(f" Image grid THW: {sample['image_grid_thw']}")


# Show label masking
print(f"\nLabel masking analysis:")
labels = sample["labels"]
@ -691,29 +698,29 @@ if __name__ == "__main__":
print(f" Total tokens: {total_count}")
print(f" Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
print(f" Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")


# Find the transition point
transition_idx = None
for i in range(len(labels) - 1):
if labels[i] == -100 and labels[i + 1] != -100:
transition_idx = i + 1
break


if transition_idx:
print(f" Transition from masked to unmasked at position: {transition_idx}")


# Print all tokens
input_ids = sample["input_ids"]
print(f"\nAll tokens ({len(input_ids)} total):")
print("Format: [index] Token (repr) | Label | Token ID")
print("-" * 80)


for i in range(len(input_ids)):
token = processor.tokenizer.decode([input_ids[i]])
token_repr = repr(token)
label = labels[i] if i < len(labels) else "N/A"
token_id = input_ids[i]


# Mark special positions
marker = ""
if transition_idx and i == transition_idx:
@ -722,64 +729,65 @@ if __name__ == "__main__":
marker = " <-- START"
elif label != -100 and i > 0 and labels[i - 1] == -100:
marker = " <-- response begins"


print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")


# Calculate and show token statistics after the table
print(f"\nToken statistics:")


# Count consecutive high-value tokens that represent the image
# Qwen uses tokens like 151859, 151860, etc. for image patches
image_token_threshold = 151000 # Typical threshold for Qwen image tokens
image_token_count = np.sum(input_ids > image_token_threshold)


# Calculate prompt tokens (everything masked)
prompt_token_count = masked_count


# Calculate output tokens (everything not masked)
output_token_count = total_count - masked_count


# Calculate non-image prompt tokens
non_image_prompt_tokens = prompt_token_count - image_token_count


print(f" Image tokens: {image_token_count}")
print(f" Prompt tokens (total): {prompt_token_count}")
print(f" Prompt tokens (non-image): {non_image_prompt_tokens}")
print(f" Output tokens: {output_token_count}")
print(f" Total sequence length: {total_count}")


# Analyze token length distribution across entire dataset
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n\n=== Analyzing token length distribution across entire dataset ===")
print(f"Processing {len(dataset)} samples...")


# Function to process a single sample
def process_sample(idx):
try:
current_sample = dataset[idx]
if 'labels' in current_sample:
if "labels" in current_sample:
# Count total sequence length (all tokens, prompt + completion)
labels = current_sample['labels']
labels = current_sample["labels"]
total_length = len(labels)
return (idx, total_length, None)
return (idx, None, "No labels in sample")
except Exception as e:
return (idx, None, str(e))


# Process samples in parallel with progress bar
sequence_lengths = []
max_sequence_length = 0
max_sequence_sample_idx = 0
errors = []


# Determine number of workers (use fewer workers to avoid memory issues)
import multiprocessing

num_workers = min(multiprocessing.cpu_count() // 2, 8)


with ProcessPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = {executor.submit(process_sample, idx): idx for idx in range(len(dataset))}


# Process results with progress bar
with tqdm(total=len(dataset), desc="Analyzing samples") as pbar:
for future in as_completed(futures):
@ -796,16 +804,16 @@ if __name__ == "__main__":
except Exception as e:
errors.append((idx, f"Future error: {e}"))
pbar.update(1)


if errors:
print(f"\nEncountered {len(errors)} errors during processing")
if len(errors) <= 5:
for idx, error in errors:
print(f" Sample {idx}: {error}")


if sequence_lengths:
sequence_lengths = np.array(sequence_lengths)


print(f"\nTotal sequence length statistics (prompt + completion):")
print(f" Total samples analyzed: {len(sequence_lengths)}")
print(f" Max sequence length: {max_sequence_length} tokens (sample index: {max_sequence_sample_idx})")
@ -813,37 +821,37 @@ if __name__ == "__main__":
print(f" Mean sequence length: {np.mean(sequence_lengths):.1f} tokens")
print(f" Median sequence length: {np.median(sequence_lengths):.1f} tokens")
print(f" Std dev: {np.std(sequence_lengths):.1f} tokens")


# Create histogram with 100-token buckets
print(f"\nSequence length histogram (100-token buckets):")


# Define buckets
bucket_size = 100
max_bucket = ((max_sequence_length // bucket_size) + 1) * bucket_size
buckets = list(range(0, max_bucket + bucket_size, bucket_size))


# Count samples in each bucket
hist, _ = np.histogram(sequence_lengths, bins=buckets)


# Find max count for scaling
max_count = max(hist)
bar_width = 50 # Width of histogram bars


print(f"\n{'Range':>15} | {'Count':>6} | Distribution")
print("-" * 80)


for i in range(len(hist)):
start = buckets[i]
end = buckets[i + 1] - 1
count = hist[i]


# Create bar
if max_count > 0:
bar_length = int((count / max_count) * bar_width)
bar = "█" * bar_length
else:
bar = ""


range_str = f"{start:>5}-{end:>5}"
print(f"{range_str:>15} | {count:>6} | {bar}")


@ -9,22 +9,22 @@ Supports model souping (averaging weights of multiple checkpoints).

Usage:
python prepare_olmocr_checkpoint.py <source_path> <destination_path>


source_path: Path to checkpoint (local or S3)
destination_path: Where to save prepared checkpoint (local or S3)

For souping multiple checkpoints:
python prepare_olmocr_checkpoint.py <source1> <source2> ... <destination>


This will average the weights of all sources and prepare the souped checkpoint.

Examples:
# Single local to local
python prepare_olmocr_checkpoint.py /path/to/checkpoint /path/to/output


# Souping multiple S3 to S3
python prepare_olmocr_checkpoint.py s3://bucket/ckpt1 s3://bucket/ckpt2 s3://bucket/souped


# Mixed souping
python prepare_olmocr_checkpoint.py s3://bucket/ckpt1 /local/ckpt2 s3://bucket/souped
"""
@ -38,9 +38,9 @@ import tempfile

import boto3
import requests
import torch
from smart_open import smart_open
from tqdm import tqdm
import torch

try:
from safetensors.torch import load_file, save_file
@ -50,32 +50,16 @@ except ImportError:
from olmocr.s3_utils import parse_s3_path

# Hugging Face model IDs for tokenizer files
HF_MODEL_IDS = {
"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct",
"Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"
}
HF_MODEL_IDS = {"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct", "Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"}

# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
"chat_template.json",
"merges.txt",
"preprocessor_config.json",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json"
]
TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"]

# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

# Files to exclude from copying (training-related files)
EXCLUDED_FILES = {
"optimizer.pt",
"scheduler.pt",
"rng_state.pth",
"trainer_state.json",
"training_args.bin"
}
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}

s3_client = boto3.client("s3")

@ -89,34 +73,34 @@ def download_file_from_hf(filename: str, destination_dir: str, hf_base_url: str)
"""Download a file from Hugging Face model repository."""
url = f"{hf_base_url}/{filename}"
local_path = os.path.join(destination_dir, filename)


print(f"Downloading {filename} from Hugging Face...")
response = requests.get(url, stream=True)
response.raise_for_status()


with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)


print(f"Downloaded {filename}")


def detect_checkpoint_architecture(config_path: str) -> str:
"""Detect and validate the checkpoint architecture."""
print(f"Detecting checkpoint architecture from {config_path}...")


with smart_open(config_path, "r") as f:
config_data = json.load(f)


architectures = config_data.get("architectures", [])


# Find the supported architecture
detected_architecture = None
for arch in architectures:
if arch in SUPPORTED_ARCHITECTURES:
detected_architecture = arch
break


if not detected_architecture:
# Try to detect from model name
model_name = config_data.get("name_or_path", "")
@ -125,11 +109,8 @@ def detect_checkpoint_architecture(config_path: str) -> str:
elif "Qwen2-VL" in model_name:
detected_architecture = "Qwen2VLForConditionalGeneration"
else:
raise ValueError(
f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} "
f"but found: {architectures}"
)

raise ValueError(f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} " f"but found: {architectures}")

print(f"✓ Detected architecture: {detected_architecture}")
return detected_architecture
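
For example, a checkpoint config along these lines (hypothetical contents) would be accepted and reported as Qwen2_5_VLForConditionalGeneration:

config_data = {
    "architectures": ["Qwen2_5_VLForConditionalGeneration"],
    "name_or_path": "Qwen/Qwen2.5-VL-7B-Instruct",
}
# detect_checkpoint_architecture() picks the first entry that appears in
# SUPPORTED_ARCHITECTURES, falling back to the model name only when the
# "architectures" list has no supported entry.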

@ -137,7 +118,7 @@ def detect_checkpoint_architecture(config_path: str) -> str:
def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
"""Copy files from local directory to local directory."""
os.makedirs(dest_dir, exist_ok=True)


# Get list of files to copy
files_to_copy = []
for root, _, files in os.walk(source_dir):
@ -148,9 +129,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
src_path = os.path.join(root, file)
rel_path = os.path.relpath(src_path, source_dir)
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))


print(f"Copying {len(files_to_copy)} files from {source_dir} to {dest_dir}...")


for src_path, dst_path in tqdm(files_to_copy, desc="Copying files"):
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
shutil.copy2(src_path, dst_path)
@ -170,35 +151,32 @@ def upload_file_to_s3(local_path: str, bucket: str, key: str) -> None:
def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> None:
"""Copy files from S3 to local directory."""
os.makedirs(dest_dir, exist_ok=True)


# List all objects in source
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)


download_tasks = []
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith("/"):
continue


filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue


rel_path = os.path.relpath(key, source_prefix)
local_path = os.path.join(dest_dir, rel_path)
download_tasks.append((source_bucket, key, local_path))


print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")


with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(download_file_from_s3, bucket, key, local_path)
for bucket, key, local_path in download_tasks
]

futures = [executor.submit(download_file_from_s3, bucket, key, local_path) for bucket, key, local_path in download_tasks]

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
future.result()

@ -216,15 +194,12 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
rel_path = os.path.relpath(local_path, source_dir)
s3_key = os.path.join(dest_prefix, rel_path)
upload_tasks.append((local_path, dest_bucket, s3_key))


print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")


with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]

futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
future.result()

@ -234,26 +209,26 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
# List all objects in source
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=source_bucket, Prefix=source_prefix)


copy_tasks = []
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith("/"):
continue


filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue


rel_path = os.path.relpath(key, source_prefix)
dest_key = os.path.join(dest_prefix, rel_path)
copy_source = {"Bucket": source_bucket, "Key": key}
copy_tasks.append((copy_source, dest_bucket, dest_key))


print(f"Copying {len(copy_tasks)} files from s3://{source_bucket}/{source_prefix} to s3://{dest_bucket}/{dest_prefix}...")


for copy_source, bucket, key in tqdm(copy_tasks, desc="Copying"):
s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=key)

@ -350,10 +325,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)

if file_path.endswith('.safetensors'):
sum_state = load_file(file_path, device='cpu')
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
for other_path in all_paths[1:]:
other_state = load_file(other_path, device='cpu')
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -363,10 +338,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
for k in sum_state:
sum_state[k] /= n
save_file(sum_state, souped_path)
elif file_path.endswith('.bin'):
sum_state = torch.load(file_path, map_location='cpu')
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location='cpu')
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@ -390,19 +365,16 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:

# Download tokenizer files from Hugging Face
print("\nDownloading tokenizer files from Hugging Face...")


if is_s3_path(dest_path):
# Download to temp directory first, then upload to S3
with tempfile.TemporaryDirectory() as temp_dir:
# Download files
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()


# Upload to S3
dest_bucket, dest_prefix = parse_s3_path(dest_path)
upload_tasks = []
@ -410,25 +382,19 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
local_path = os.path.join(temp_dir, filename)
s3_key = os.path.join(dest_prefix, filename)
upload_tasks.append((local_path, dest_bucket, s3_key))


print("Uploading tokenizer files to S3...")
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in concurrent.futures.as_completed(futures):
future.result()
else:
# Download directly to destination
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, dest_path, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, dest_path, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()


print(f"\n✓ Successfully prepared checkpoint at {dest_path}")


@ -436,26 +402,26 @@ def main():
parser = argparse.ArgumentParser(
description="Prepare OlmOCR checkpoint for deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__.split("Usage:")[1] # Use the docstring for epilog
epilog=__doc__.split("Usage:")[1], # Use the docstring for epilog
)
parser.add_argument("paths", nargs='+', help="One or more source paths followed by destination path (local or S3)")

parser.add_argument("paths", nargs="+", help="One or more source paths followed by destination path (local or S3)")

args = parser.parse_args()


if len(args.paths) < 2:
parser.error("At least one source and one destination required")


sources = args.paths[:-1]
destination = args.paths[-1]


try:
prepare_checkpoints(sources, destination)
except Exception as e:
print(f"\n❌ Error: {e}")
return 1


return 0


if __name__ == "__main__":
exit(main())
exit(main())

@ -5,6 +5,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
from typing import Optional

import numpy as np
import torch
@ -18,7 +19,6 @@ from transformers import (
TrainingArguments,
)

from typing import Optional
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset

@ -47,13 +47,13 @@ class QwenDataCollator:
input_ids = torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]
attention_mask = torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
labels = torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]


# Trim to max_token_len if specified
if self.max_token_len is not None:
input_ids = input_ids[:self.max_token_len]
attention_mask = attention_mask[:self.max_token_len]
labels = labels[:self.max_token_len]

input_ids = input_ids[: self.max_token_len]
attention_mask = attention_mask[: self.max_token_len]
labels = labels[: self.max_token_len]

batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
batch["labels"].append(labels)
@ -103,7 +103,7 @@ def main():
if config.project_name:
os.environ["WANDB_PROJECT"] = config.project_name
logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")


# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
@ -209,7 +209,7 @@ def main():
adam_epsilon=config.training.adam_epsilon,
weight_decay=config.training.weight_decay,
max_grad_norm=config.training.max_grad_norm,
bf16=True, # We're sticking with this known good reduced precision option
bf16=True, # We're sticking with this known good reduced precision option
eval_strategy=config.training.evaluation_strategy,
eval_steps=config.training.eval_steps,
save_strategy=config.training.save_strategy,

@ -167,7 +167,7 @@ reportPrivateImportUsage = false
line-length = 160
target-version = "py311"
exclude = ["olmocr/train/molmo", "tests/*"]
ignore = ["E722"] #igore bare except
ignore = ["E722", "F541"] #igore bare except, and f string without placeholders

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
