mirror of https://github.com/allenai/olmocr.git (synced 2025-12-13 08:11:22 +00:00)
Lint fixes
This commit is contained in:
parent 5ec49672ea
commit 6e8272413c
@@ -17,7 +17,7 @@ import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from functools import cache
from io import BytesIO
from urllib.parse import urlparse

@@ -34,7 +34,6 @@ from olmocr.check import (
check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.train.dataloader import FrontMatterParser
from olmocr.filter.filter import Language, PdfFilter
from olmocr.image_utils import convert_image_to_pdf_bytes, is_jpeg, is_png
from olmocr.metrics import MetricsKeeper, WorkerTracker
@@ -48,6 +47,7 @@ from olmocr.s3_utils import (
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue

@@ -227,7 +227,9 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:

# Enable guided decoding regex if needed
if args.guided_decoding:
query["guided_regex"] = r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
query["guided_regex"] = (
r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)

logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
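Aside (not part of the commit): the guided-decoding pattern above can be sanity-checked against a hand-written model response. A minimal sketch, with the pattern copied verbatim and a made-up page body:

import re

GUIDED_REGEX = (
    r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)"
    r"\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)"
    r"\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)

# A front-matter block followed by the plain-text body matches the second
# alternative of the trailing group ("---\n" plus arbitrary text).
sample_response = (
    "---\n"
    "primary_language: en\n"
    "is_rotation_valid: True\n"
    "rotation_correction: 0\n"
    "is_table: False\n"
    "is_diagram: False\n"
    "---\n"
    "The plain text of the page goes here."
)

assert re.fullmatch(GUIDED_REGEX, sample_response) is not None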
@@ -329,6 +331,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
is_fallback=True,
)


async def process_pdf(args, worker_id: int, pdf_orig_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False) as tf:
try:
@@ -1016,7 +1019,11 @@ async def main():
)

parser.add_argument("--gpu-memory-utilization", type=float, help="Fraction of VRAM vLLM may pre-allocate for KV-cache " "(passed through to vllm serve).")
parser.add_argument("--max_model_len", type=int, help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",)
parser.add_argument(
"--max_model_len",
type=int,
help="Upper bound (tokens) vLLM will allocate KV-cache for; " "passed through to vllm serve as --max-model-len.",
)

parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
@@ -1227,12 +1234,12 @@ async def main():

# Output finished_on_attempt statistics
logger.info("\nPages finished by attempt number:")
total_finished = sum(total_metrics.get(f'finished_on_attempt_{i}', 0) for i in range(args.max_page_retries))
total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(args.max_page_retries))
cumulative = 0

for i in range(args.max_page_retries):
if f'finished_on_attempt_{i}' in total_metrics:
count = total_metrics[f'finished_on_attempt_{i}']
if f"finished_on_attempt_{i}" in total_metrics:
count = total_metrics[f"finished_on_attempt_{i}"]
cumulative += count
percentage = (count / total_finished * 100) if total_finished > 0 else 0
cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
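Aside (not from the commit): a toy walk-through of how the per-attempt counters above roll up into cumulative percentages, using made-up metric values:

# Made-up metrics: 80 pages succeed on the first attempt, 15 on the second, 5 on the third.
total_metrics = {"finished_on_attempt_0": 80, "finished_on_attempt_1": 15, "finished_on_attempt_2": 5}
max_page_retries = 3

total_finished = sum(total_metrics.get(f"finished_on_attempt_{i}", 0) for i in range(max_page_retries))
cumulative = 0
for i in range(max_page_retries):
    if f"finished_on_attempt_{i}" in total_metrics:
        count = total_metrics[f"finished_on_attempt_{i}"]
        cumulative += count
        percentage = (count / total_finished * 100) if total_finished > 0 else 0
        cumulative_percentage = (cumulative / total_finished * 100) if total_finished > 0 else 0
        print(f"Attempt {i}: {count} pages ({percentage:.1f}%), cumulative {cumulative_percentage:.1f}%")
# Attempt 0: 80 pages (80.0%), cumulative 80.0%
# Attempt 1: 15 pages (15.0%), cumulative 95.0%
# Attempt 2: 5 pages (5.0%), cumulative 100.0%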
@@ -1,8 +1,8 @@
from .prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
build_no_anchoring_yaml_prompt,
build_openai_silver_data_prompt,
extract_raw_text,
openai_response_format_schema,
)

@@ -109,9 +109,9 @@ def build_finetuning_prompt(base_text: str) -> str:

def build_no_anchoring_yaml_prompt() -> str:
return (
f"Attached is one page of a document that you must process. "
f"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
f"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
"Attached is one page of a document that you must process. "
"Just return the plain text representation of this document as if you were reading it naturally. Convert equations to LateX and tables to markdown.\n"
"Return your output as markdown, with a front matter section on top specifying values for the primary_language, is_rotation_valid, rotation_correction, is_table, and is_diagram parameters."
)


@@ -6,22 +6,21 @@ Processes prompts and images from WildVision-bench until finding significant mis

import argparse
import asyncio
import gc
import os
import glob
import tempfile
import shutil
import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor, AutoModelForVision2Seq
from huggingface_hub import snapshot_download
import random
import numpy as np
from typing import List, Dict
import base64
from io import BytesIO
import PIL.Image
import logging
import os
import random
import shutil
import tempfile
from io import BytesIO
from typing import Dict, List

import numpy as np
import PIL.Image
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForVision2Seq, AutoProcessor
from vllm import LLM, SamplingParams

from olmocr.pipeline import build_page_query
from olmocr.s3_utils import download_directory
@@ -51,7 +50,7 @@ async def download_model(model_name_or_path: str, max_retries: int = 5):
if retry == max_retries - 1:
raise # Raise on final attempt and fail the job
logger.warning(f"Model download failed (attempt {retry + 1}/{max_retries}), retrying...")
await asyncio.sleep(2 ** retry) # Exponential backoff
await asyncio.sleep(2**retry) # Exponential backoff


def image_to_base64_data_url(image):
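Aside (not from the commit): 2**retry yields delays of 1, 2, 4, 8 and 16 seconds over five attempts. A hedged sketch of the same retry shape as a standalone helper (retry_with_backoff is a hypothetical name):

import asyncio

async def retry_with_backoff(op, max_retries: int = 5):
    # op is any zero-argument coroutine function; retries with exponential backoff.
    for retry in range(max_retries):
        try:
            return await op()
        except Exception:
            if retry == max_retries - 1:
                raise  # final attempt, surface the error
            await asyncio.sleep(2**retry)  # exponential backoff: 1s, 2s, 4s, 8s, ...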
@@ -71,14 +70,11 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
np.random.seed(seed)

# Import huggingface_hub utilities to list files
from huggingface_hub import list_repo_files, hf_hub_download
from huggingface_hub import hf_hub_download, list_repo_files

# List all PDF files in the repository
print("Listing PDF files in dataset...")
all_files = list_repo_files(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
repo_type="dataset"
)
all_files = list_repo_files(repo_id="allenai/olmOCR-mix-0225-benchmarkset", repo_type="dataset")

# Filter for PDF files in the pdfs directory
pdf_files = [f for f in all_files if f.startswith("pdfs/") and f.endswith(".pdf")]
@@ -104,20 +100,10 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
try:
# Download individual PDF file
print(f"Downloading {pdf_file}...")
local_pdf_path = hf_hub_download(
repo_id="allenai/olmOCR-mix-0225-benchmarkset",
filename=pdf_file,
repo_type="dataset",
local_dir=temp_dir
)
local_pdf_path = hf_hub_download(repo_id="allenai/olmOCR-mix-0225-benchmarkset", filename=pdf_file, repo_type="dataset", local_dir=temp_dir)

# Build page query for page 1 of each PDF
query = await build_page_query(
local_pdf_path=local_pdf_path,
page=1,
target_longest_image_dim=1280,
image_rotation=0
)
query = await build_page_query(local_pdf_path=local_pdf_path, page=1, target_longest_image_dim=1280, image_rotation=0)
queries.append(query)
except Exception as e:
print(f"Error processing {os.path.basename(pdf_file)}: {e}")
@@ -126,27 +112,28 @@ async def load_pdf_prompts(num_samples: int = 100, seed: int = 42, max_length: i
print(f"Successfully processed {len(queries)} PDFs")
return queries


def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, sampling_params, device, args):
"""Process a single prompt with image and return comparison results."""
# Track if we found the first mismatch for max_prob_first_diff
found_first_mismatch = False
max_prob_first_diff = 0.0
# Extract messages from the sample (which is the output of build_page_query)
messages = sample['messages']
messages = sample["messages"]

# Extract the text prompt and image from the messages
user_message = messages[0]
text_prompt = None
image_base64 = None

for content in user_message['content']:
if content['type'] == 'text':
text_prompt = content['text']
elif content['type'] == 'image_url':
image_url = content['image_url']['url']
for content in user_message["content"]:
if content["type"] == "text":
text_prompt = content["text"]
elif content["type"] == "image_url":
image_url = content["image_url"]["url"]
# Extract base64 data after the comma
if ',' in image_url:
image_base64 = image_url.split(',')[1]
if "," in image_url:
image_base64 = image_url.split(",")[1]
else:
image_base64 = image_url
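Aside (not from the commit): the split(",") above assumes an OpenAI-style data URL of the form "data:image/png;base64,<payload>". A small sketch of the same extraction as a standalone helper (decode_data_url is a hypothetical name):

import base64
from io import BytesIO

import PIL.Image

def decode_data_url(image_url: str) -> PIL.Image.Image:
    # Everything after the first comma is the base64 payload; bare base64 passes through.
    image_base64 = image_url.split(",", 1)[1] if "," in image_url else image_url
    return PIL.Image.open(BytesIO(base64.b64decode(image_base64)))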
@@ -172,7 +159,11 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
prompt_token_ids = output.prompt_token_ids
generated_token_ids = output.outputs[0].token_ids

print(f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..." if len(prompt_token_ids) > 10 else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}")
print(
f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids[:10]}..."
if len(prompt_token_ids) > 10
else f"Prompt tokens ({len(prompt_token_ids)}): {prompt_token_ids}"
)
print(f"Generated tokens ({len(generated_token_ids)}): {generated_token_ids}")
print(f"Generated text: {processor.decode(generated_token_ids, skip_special_tokens=True)}")

@@ -182,21 +173,9 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# HuggingFace forward pass
print("\n=== HuggingFace Forward Pass ===")
# Prepare inputs for HF model using the extracted image and text
conversation = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": text_prompt}
]
}
]
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": text_prompt}]}]
hf_text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(
text=[hf_text_prompt],
images=[image],
return_tensors="pt"
).to(device)
inputs = processor(text=[hf_text_prompt], images=[image], return_tensors="pt").to(device)

print("INPUTS", inputs)

@@ -224,7 +203,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp

# Compare ALL tokens (prompt + generated)
for pos, token_id in enumerate(all_token_ids):
token_str = processor.decode([token_id], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
token_str = processor.decode([token_id], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")

# Determine if this is a prompt or generated token
is_prompt = pos < len(prompt_token_ids)
@@ -265,7 +244,7 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
# Decode HF argmax token (only show if mismatch)
hf_token_str = ""
if token_id != hf_argmax:
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace('\n', '\\n').replace('\r', '\\r')
hf_token_str = processor.decode([hf_argmax], skip_special_tokens=False).replace("\n", "\\n").replace("\r", "\\r")

print(f"{pos:>4} {token_id:>8} {token_str:>20} {token_type:>8} {vllm_prob_str:>12} {hf_argmax:>10} {hf_prob:>12.6f} {match:>6} {hf_token_str:>20}")
else:
@@ -296,27 +275,21 @@ def process_single_prompt(sample: Dict[str, any], llm, hf_model, processor, samp
print("No mismatches found in generated tokens")

return {
'first_mismatch_idx': first_mismatch_idx,
'max_prob_first_diff': max_prob_first_diff,
'match_rate': match_rate,
'num_generated': len(generated_token_ids)
"first_mismatch_idx": first_mismatch_idx,
"max_prob_first_diff": max_prob_first_diff,
"match_rate": match_rate,
"num_generated": len(generated_token_ids),
}


async def async_main():
parser = argparse.ArgumentParser(description="Batch compare VLM inference between vLLM and HuggingFace")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct",
help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20,
help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0,
help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100,
help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20,
help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42,
help="Random seed for prompt selection")
parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct", help="Model name or path")
parser.add_argument("--max-tokens", type=int, default=20, help="Maximum tokens to generate per prompt")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
parser.add_argument("--num-prompts", type=int, default=100, help="Number of prompts to load from WildVision")
parser.add_argument("--prob-threshold", type=float, default=0.20, help="Probability difference threshold to stop")
parser.add_argument("--seed", type=int, default=42, help="Random seed for prompt selection")
args = parser.parse_args()

print(f"Model: {args.model}")
@@ -335,22 +308,13 @@ async def async_main():
print("\n=== Loading HuggingFace Model ===")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor_hf = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
hf_model = AutoModelForVision2Seq.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.float16,
device_map="auto"
)
hf_model = AutoModelForVision2Seq.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")
hf_model.eval()

# Create vLLM engine once
print("\n=== Creating vLLM Engine ===")
llm = LLM(model=model_path, trust_remote_code=True, gpu_memory_utilization=0.5)
sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=args.max_tokens,
logprobs=1 # Get top-1 logprobs
)
sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_tokens, logprobs=1) # Get top-1 logprobs

# Process samples until finding significant mismatch
print("\n=== Processing Samples ===")
@@ -367,7 +331,7 @@ async def async_main():
all_results.append(result)

# Check if we found significant mismatch
if result['first_mismatch_idx'] is not None and result['max_prob_first_diff'] > args.prob_threshold:
if result["first_mismatch_idx"] is not None and result["max_prob_first_diff"] > args.prob_threshold:
print(f"\n\n{'*'*80}")
print(f"*** FOUND SIGNIFICANT MISMATCH ***")
print(f"*** First mismatch probability difference: {result['max_prob_first_diff']:.6f} > {args.prob_threshold} ***")
@@ -380,20 +344,20 @@ async def async_main():
print(f"{'='*80}")

total_samples = len(all_results)
samples_with_mismatches = sum(1 for r in all_results if r['first_mismatch_idx'] is not None)
total_tokens_generated = sum(r['num_generated'] for r in all_results)
samples_with_mismatches = sum(1 for r in all_results if r["first_mismatch_idx"] is not None)
total_tokens_generated = sum(r["num_generated"] for r in all_results)

print(f"Total samples processed: {total_samples}")
print(f"Samples with mismatches: {samples_with_mismatches} ({samples_with_mismatches/total_samples*100:.1f}%)")
print(f"Total tokens generated: {total_tokens_generated}")

if samples_with_mismatches > 0:
avg_match_rate = sum(r['match_rate'] for r in all_results) / total_samples
max_prob_diffs = [r['max_prob_first_diff'] for r in all_results if r['first_mismatch_idx'] is not None]
avg_match_rate = sum(r["match_rate"] for r in all_results) / total_samples
max_prob_diffs = [r["max_prob_first_diff"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_prob_diff = sum(max_prob_diffs) / len(max_prob_diffs)
max_prob_diff_overall = max(max_prob_diffs)

first_mismatch_positions = [r['first_mismatch_idx'] for r in all_results if r['first_mismatch_idx'] is not None]
first_mismatch_positions = [r["first_mismatch_idx"] for r in all_results if r["first_mismatch_idx"] is not None]
avg_first_mismatch_pos = sum(first_mismatch_positions) / len(first_mismatch_positions)

print(f"\nMismatch Statistics:")
@@ -26,18 +26,22 @@ import shutil
import tempfile
from io import BytesIO
from pathlib import Path
from typing import Optional, Tuple, Union, List
from typing import List, Optional, Tuple, Union

import boto3
import torch
from datasets import Dataset
from llmcompressor import oneshot
from PIL import Image
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import (
AutoProcessor,
AutoTokenizer,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)

from olmocr.s3_utils import parse_s3_path
from olmocr.pipeline import build_page_query

from olmocr.s3_utils import parse_s3_path

s3_client = boto3.client("s3")

@@ -62,7 +66,7 @@ def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
# Verify all PDFs exist
valid_paths = []
for path in pdf_paths:
if os.path.exists(path) and path.endswith('.pdf'):
if os.path.exists(path) and path.endswith(".pdf"):
valid_paths.append(path)
else:
print(f" Warning: Skipping invalid path: {path}")
@@ -101,9 +105,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
images.append(image)

# Apply chat template to get the text
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Process with tokenizer
inputs = processor(
@@ -123,7 +125,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
all_datasets = []

for i in range(0, len(dataset_items), batch_size):
batch = dataset_items[i:i + batch_size]
batch = dataset_items[i : i + batch_size]
# Flatten the batch into a dict of lists
batch_dict = {}
for key in batch[0].keys():
@@ -138,6 +140,7 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> Datase
return all_datasets[0]
else:
from datasets import concatenate_datasets

return concatenate_datasets(all_datasets)
else:
return Dataset.from_dict({})
@@ -185,7 +188,9 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
print(f" Uploaded {rel_path}")


def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
def load_model_and_tokenizer(
source_path: str,
) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
"""Load model and tokenizer from source path (local or S3)."""
if is_s3_path(source_path):
# Download from S3 to temporary directory
@@ -210,35 +215,19 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditio
# Load appropriate model class based on name
if "Qwen2.5-VL" in model_name:
print("Detected Qwen2.5-VL model")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
elif "Qwen2-VL" in model_name:
print("Detected Qwen2-VL model")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
# Default to checking architectures list
architectures = config.get("architectures", [])
if "Qwen2_5_VLForConditionalGeneration" in architectures:
print("Detected Qwen2.5-VL model from architectures")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")
else:
print("Detected Qwen2-VL model from architectures")
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_path,
device_map="auto",
torch_dtype="auto"
)
model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype="auto")

print(f"Loading tokenizer from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -268,7 +257,9 @@ def data_collator(batch):
return {key: torch.tensor(value) for key, value in batch[0].items()}


def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None) -> None:
def compress_checkpoint(
source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 512, calibration_pdfs: Optional[List[str]] = None
) -> None:
"""Compress OlmOCR checkpoint using FP8 quantization."""
# Load model and tokenizer
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -304,14 +295,7 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
print(f"\nApplying quantization using recipe: {recipe_path}")

if dataset:
oneshot(
model=model,
recipe=recipe_path,
dataset=dataset,
max_seq_length=8192,
num_calibration_samples=len(dataset),
data_collator=data_collator
)
oneshot(model=model, recipe=recipe_path, dataset=dataset, max_seq_length=8192, num_calibration_samples=len(dataset), data_collator=data_collator)
else:
oneshot(model=model, recipe=recipe_path)

@@ -377,15 +361,18 @@ Examples:

# Using recursive glob pattern
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
"""
""",
)
parser.add_argument("source", help="Source checkpoint path (local or S3)")
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
parser.add_argument("--num-calibration-samples", type=int, default=512,
help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument("--calibration-pdfs", type=str, default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
parser.add_argument("--num-calibration-samples", type=int, default=512, help="Number of calibration samples to use (default: 512, set to 0 to disable)")
parser.add_argument(
"--calibration-pdfs",
type=str,
default=None,
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.",
)

args = parser.parse_args()

@@ -396,29 +383,27 @@ Examples:
pattern = args.calibration_pdfs

# Handle both absolute and relative paths with recursive glob
if '**' in pattern:
if "**" in pattern:
# For recursive patterns, we need to handle them specially
if pattern.startswith('/'):
if pattern.startswith("/"):
# Absolute path with **
parts = pattern.split('**')
parts = pattern.split("**")
base_path = Path(parts[0])
glob_pattern = '**' + parts[1] if len(parts) > 1 else '**/*.pdf'
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip('/')))
glob_pattern = "**" + parts[1] if len(parts) > 1 else "**/*.pdf"
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip("/")))
else:
# Relative path with **
calibration_pdfs = list(Path('.').glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == '.pdf']
calibration_pdfs = list(Path(".").glob(pattern))
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == ".pdf"]
else:
# Use standard glob for non-recursive patterns
calibration_pdfs = glob.glob(pattern)
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith('.pdf')]
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith(".pdf")]

print(f"Found {len(calibration_pdfs)} PDF files matching pattern: {args.calibration_pdfs}")

compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)


if __name__ == "__main__":
exit(main())
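Aside (not from the commit): the branching above boils down to using pathlib's recursive glob for "**" patterns and glob.glob otherwise. A condensed sketch under that assumption (find_pdfs and the example paths are hypothetical):

import glob
from pathlib import Path

def find_pdfs(pattern: str) -> list[str]:
    # "**" patterns go through pathlib's recursive glob, everything else through glob.glob
    if "**" in pattern:
        base, _, rest = pattern.partition("**")
        glob_pattern = "**" + rest if rest else "**/*.pdf"
        matches = Path(base or ".").glob(glob_pattern.lstrip("/"))
        return [str(p.absolute()) for p in matches if p.suffix.lower() == ".pdf"]
    return [p for p in glob.glob(pattern) if p.lower().endswith(".pdf")]

# e.g. find_pdfs("/data/**/*.pdf") or find_pdfs("pdfs/*.pdf")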
@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from omegaconf import DictConfig, OmegaConf
from omegaconf import OmegaConf


@dataclass
@@ -314,10 +314,10 @@ class Config:
FrontMatterOutputFormat,
FrontMatterParser,
InstructUserMessages,
JSONOutputFormat,
LatexBracketNormalizer,
NewYamlFinetuningPromptWithAnchoring,
NewYamlFinetuningPromptWithNoAnchoring,
JSONOutputFormat,
PDFRenderer,
RandomTokenFlipper,
StaticLengthDocumentAnchoring,

@@ -1,4 +1,5 @@
import base64
import json
import logging
import re
from abc import ABC, abstractmethod
@@ -8,8 +9,7 @@ from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
import json
from typing import Any, Callable, Dict, List, Optional, Type, TypeAlias, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeAlias

import numpy as np
import yaml
@@ -60,6 +60,7 @@ def validate_pdf_pair(md_path: Path) -> Tuple[Optional[Dict[str, Path]], Optiona

# Test that document anchoring works
from olmocr.prompts.anchor import get_anchor_text

get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)

return {"markdown_path": md_path, "pdf_path": pdf_path}, None
@@ -205,11 +206,11 @@ class FrontMatterParser(PipelineStep):
value = front_matter_dict[field_name]

# Handle type conversions
if field_type == int and isinstance(value, str):
if field_type is int and isinstance(value, str):
kwargs[field_name] = int(value)
elif field_type == bool and isinstance(value, str):
elif field_type is bool and isinstance(value, str):
kwargs[field_name] = value.lower() == "true"
elif field_type == Optional[str]:
elif field_type is Optional[str]:
kwargs[field_name] = value if value else None
else:
kwargs[field_name] = value
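Aside (not from the commit): a toy illustration of the string-to-type coercion above. FrontMatter and coerce are hypothetical stand-ins; the is-comparisons mirror the ones in the hunk:

from dataclasses import dataclass
from typing import Optional, get_type_hints

@dataclass
class FrontMatter:  # hypothetical stand-in for the parsed front matter fields
    rotation_correction: int
    is_table: bool
    primary_language: Optional[str]

def coerce(field_type, value):
    # same identity comparisons as the hunk above
    if field_type is int and isinstance(value, str):
        return int(value)
    if field_type is bool and isinstance(value, str):
        return value.lower() == "true"
    if field_type is Optional[str]:
        return value if value else None
    return value

raw = {"rotation_correction": "90", "is_table": "False", "primary_language": ""}
hints = get_type_hints(FrontMatter)
fm = FrontMatter(**{name: coerce(hints[name], value) for name, value in raw.items()})
print(fm)  # FrontMatter(rotation_correction=90, is_table=False, primary_language=None)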
@@ -323,7 +324,7 @@ class FrontMatterOutputFormat(PipelineStep):

def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse

sample["response"] = (
f"""---
@@ -346,19 +347,23 @@ class JSONOutputFormat(PipelineStep):

def __call__(self, sample: Sample) -> Sample:
page_data = sample["page_data"]
assert type(page_data) == PageResponse
assert type(page_data) is PageResponse

sample["response"] = json.dumps({
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text
}, ensure_ascii=False)
sample["response"] = json.dumps(
{
"primary_language": page_data.primary_language,
"is_rotation_valid": page_data.is_rotation_valid,
"rotation_correction": page_data.rotation_correction,
"is_table": page_data.is_table,
"is_diagram": page_data.is_diagram,
"natural_text": page_data.natural_text,
},
ensure_ascii=False,
)

return sample


@dataclass(frozen=True, slots=True)
class LatexBracketNormalizer(PipelineStep):
"""Normalizes LaTeX brackets in natural text field."""
@@ -379,7 +384,7 @@ class LatexBracketNormalizer(PipelineStep):
# Order matters: process display math first, then inline
patterns = [
(r"\$\$(.+?)\$\$", r"\[\1\]"), # $$...$$ to \[...\]
(r"\$(.+?)\$", r"\(\1\)"), # $...$ to \(...\)
(r"\$(.+?)\$", r"\(\1\)"),  # $...$ to \(...\)
]

# Apply replacements
@@ -389,13 +394,14 @@ class LatexBracketNormalizer(PipelineStep):
# Update the page_data with normalized text
# Since PageResponse is frozen, we need to create a new instance
from olmocr.prompts.prompts import PageResponse

new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=page_data.is_rotation_valid,
rotation_correction=page_data.rotation_correction,
is_table=page_data.is_table,
is_diagram=page_data.is_diagram,
natural_text=text
natural_text=text,
)

sample["page_data"] = new_page_data
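Aside (not from the commit): applying the two patterns above in order converts display math before inline math, so the $$ delimiters are not consumed by the single-$ rule. A minimal sketch:

import re

text = "Inline $a^2 + b^2 = c^2$ and display $$\\int_0^1 x\\,dx = \\tfrac{1}{2}$$."
for pattern, replacement in [(r"\$\$(.+?)\$\$", r"\[\1\]"), (r"\$(.+?)\$", r"\(\1\)")]:
    text = re.sub(pattern, replacement, text)
print(text)  # Inline \(a^2 + b^2 = c^2\) and display \[\int_0^1 x\,dx = \tfrac{1}{2}\].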
@@ -608,6 +614,7 @@ if __name__ == "__main__":
# Load processor for tokenization
print(f"\nLoading processor: {config.model.name}")
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(config.model.name)

# Select dataset based on type
@@ -655,23 +662,23 @@ if __name__ == "__main__":
print("\nSample keys:", list(sample.keys()))

# If it's raw data (no tokenization)
if 'markdown_path' in sample:
if "markdown_path" in sample:
print(f"\nMarkdown file: {sample['markdown_path'].name}")
if 'pdf_path' in sample:
if "pdf_path" in sample:
print(f"PDF file: {sample['pdf_path'].name}")
if 'image' in sample and hasattr(sample['image'], 'size'):
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")
if 'page_data' in sample:
if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if 'messages' in sample:
if "messages" in sample:
print(f"\n=== Messages ===")
for i, msg in enumerate(sample['messages']):
for i, msg in enumerate(sample["messages"]):
print(f"\nMessage {i}:")
print(f" Role: {msg['role']}")
print(f" Content preview: {str(msg['content'])[:200]}...")

# If it's tokenized data
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n=== Tokenized Output ===")
print(f" Keys: {list(sample.keys())}")
print(f" Input IDs shape: {sample['input_ids'].shape}")
@@ -749,7 +756,7 @@ if __name__ == "__main__":
print(f" Total sequence length: {total_count}")

# Analyze token length distribution across entire dataset
if 'input_ids' in sample:
if "input_ids" in sample:
print(f"\n\n=== Analyzing token length distribution across entire dataset ===")
print(f"Processing {len(dataset)} samples...")

@@ -757,9 +764,9 @@ if __name__ == "__main__":
def process_sample(idx):
try:
current_sample = dataset[idx]
if 'labels' in current_sample:
if "labels" in current_sample:
# Count total sequence length (all tokens, prompt + completion)
labels = current_sample['labels']
labels = current_sample["labels"]
total_length = len(labels)
return (idx, total_length, None)
return (idx, None, "No labels in sample")
@@ -774,6 +781,7 @@ if __name__ == "__main__":

# Determine number of workers (use fewer workers to avoid memory issues)
import multiprocessing

num_workers = min(multiprocessing.cpu_count() // 2, 8)

with ProcessPoolExecutor(max_workers=num_workers) as executor:
@@ -38,9 +38,9 @@ import tempfile

import boto3
import requests
import torch
from smart_open import smart_open
from tqdm import tqdm
import torch

try:
from safetensors.torch import load_file, save_file
@@ -50,32 +50,16 @@ except ImportError:
from olmocr.s3_utils import parse_s3_path

# Hugging Face model IDs for tokenizer files
HF_MODEL_IDS = {
"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct",
"Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"
}
HF_MODEL_IDS = {"Qwen2VLForConditionalGeneration": "Qwen/Qwen2-VL-7B-Instruct", "Qwen2_5_VLForConditionalGeneration": "Qwen/Qwen2.5-VL-7B-Instruct"}

# Required tokenizer files to download from Hugging Face
TOKENIZER_FILES = [
"chat_template.json",
"merges.txt",
"preprocessor_config.json",
"tokenizer.json",
"tokenizer_config.json",
"vocab.json"
]
TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json", "tokenizer.json", "tokenizer_config.json", "vocab.json"]

# Supported model architectures
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]

# Files to exclude from copying (training-related files)
EXCLUDED_FILES = {
"optimizer.pt",
"scheduler.pt",
"rng_state.pth",
"trainer_state.json",
"training_args.bin"
}
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}

s3_client = boto3.client("s3")

@@ -125,10 +109,7 @@ def detect_checkpoint_architecture(config_path: str) -> str:
elif "Qwen2-VL" in model_name:
detected_architecture = "Qwen2VLForConditionalGeneration"
else:
raise ValueError(
f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} "
f"but found: {architectures}"
)
raise ValueError(f"No supported architecture found. Expected one of {SUPPORTED_ARCHITECTURES} " f"but found: {architectures}")

print(f"✓ Detected architecture: {detected_architecture}")
return detected_architecture
@@ -194,10 +175,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
print(f"Downloading {len(download_tasks)} files from s3://{source_bucket}/{source_prefix} to {dest_dir}...")

with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(download_file_from_s3, bucket, key, local_path)
for bucket, key, local_path in download_tasks
]
futures = [executor.submit(download_file_from_s3, bucket, key, local_path) for bucket, key, local_path in download_tasks]

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Downloading"):
future.result()
@@ -220,10 +198,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
print(f"Uploading {len(upload_tasks)} files from {source_dir} to s3://{dest_bucket}/{dest_prefix}...")

with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]

for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Uploading"):
future.result()
@@ -350,10 +325,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
souped_path = os.path.join(souped_dir, rel_path)
os.makedirs(os.path.dirname(souped_path), exist_ok=True)

if file_path.endswith('.safetensors'):
sum_state = load_file(file_path, device='cpu')
if file_path.endswith(".safetensors"):
sum_state = load_file(file_path, device="cpu")
for other_path in all_paths[1:]:
other_state = load_file(other_path, device='cpu')
other_state = load_file(other_path, device="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
@@ -363,10 +338,10 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
for k in sum_state:
sum_state[k] /= n
save_file(sum_state, souped_path)
elif file_path.endswith('.bin'):
sum_state = torch.load(file_path, map_location='cpu')
elif file_path.endswith(".bin"):
sum_state = torch.load(file_path, map_location="cpu")
for other_path in all_paths[1:]:
other_state = torch.load(other_path, map_location='cpu')
other_state = torch.load(other_path, map_location="cpu")
if set(sum_state.keys()) != set(other_state.keys()):
raise ValueError(f"Key mismatch in {rel_path}")
for k in sum_state:
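Aside (not from the commit): the averaging ("souping") loop above reduces to summing matching keys across checkpoints and dividing by the count. A toy sketch with two tiny state dicts:

import torch

states = [
    {"w": torch.tensor([1.0, 2.0]), "b": torch.tensor([0.0])},
    {"w": torch.tensor([3.0, 4.0]), "b": torch.tensor([2.0])},
]
sum_state = {k: v.clone() for k, v in states[0].items()}
for other_state in states[1:]:
    if set(sum_state.keys()) != set(other_state.keys()):
        raise ValueError("Key mismatch")
    for k in sum_state:
        sum_state[k] += other_state[k]
for k in sum_state:
    sum_state[k] /= len(states)  # element-wise average across checkpoints
print(sum_state)  # {'w': tensor([2., 3.]), 'b': tensor([1.])}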
@@ -396,10 +371,7 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
# Download files
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, temp_dir, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()

@@ -413,19 +385,13 @@ def prepare_checkpoints(sources: list[str], dest_path: str) -> None:

print("Uploading tokenizer files to S3...")
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(upload_file_to_s3, local_path, bucket, key)
for local_path, bucket, key in upload_tasks
]
futures = [executor.submit(upload_file_to_s3, local_path, bucket, key) for local_path, bucket, key in upload_tasks]
for future in concurrent.futures.as_completed(futures):
future.result()
else:
# Download directly to destination
with concurrent.futures.ThreadPoolExecutor(max_workers=6) as executor:
futures = [
executor.submit(download_file_from_hf, filename, dest_path, hf_base_url)
for filename in TOKENIZER_FILES
]
futures = [executor.submit(download_file_from_hf, filename, dest_path, hf_base_url) for filename in TOKENIZER_FILES]
for future in concurrent.futures.as_completed(futures):
future.result()

@@ -436,9 +402,9 @@ def main():
parser = argparse.ArgumentParser(
description="Prepare OlmOCR checkpoint for deployment",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__.split("Usage:")[1] # Use the docstring for epilog
epilog=__doc__.split("Usage:")[1], # Use the docstring for epilog
)
parser.add_argument("paths", nargs='+', help="One or more source paths followed by destination path (local or S3)")
parser.add_argument("paths", nargs="+", help="One or more source paths followed by destination path (local or S3)")

args = parser.parse_args()

@@ -5,6 +5,7 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
from typing import Optional

import numpy as np
import torch
@@ -18,7 +19,6 @@ from transformers import (
TrainingArguments,
)

from typing import Optional
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset

@@ -50,9 +50,9 @@ class QwenDataCollator:

# Trim to max_token_len if specified
if self.max_token_len is not None:
input_ids = input_ids[:self.max_token_len]
attention_mask = attention_mask[:self.max_token_len]
labels = labels[:self.max_token_len]
input_ids = input_ids[: self.max_token_len]
attention_mask = attention_mask[: self.max_token_len]
labels = labels[: self.max_token_len]

batch["input_ids"].append(input_ids)
batch["attention_mask"].append(attention_mask)
@@ -209,7 +209,7 @@ def main():
adam_epsilon=config.training.adam_epsilon,
weight_decay=config.training.weight_decay,
max_grad_norm=config.training.max_grad_norm,
bf16=True, # We're sticking with this known good reduced precision option
bf16=True,  # We're sticking with this known good reduced precision option
eval_strategy=config.training.evaluation_strategy,
eval_steps=config.training.eval_steps,
save_strategy=config.training.save_strategy,

@@ -167,7 +167,7 @@ reportPrivateImportUsage = false
line-length = 160
target-version = "py311"
exclude = ["olmocr/train/molmo", "tests/*"]
ignore = ["E722"] #igore bare except
ignore = ["E722", "F541"] #igore bare except, and f string without placeholders

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]