Cleaning up scripts, multi gpu trainer more flexible

This commit is contained in:
Jake Poznanski 2025-09-03 18:25:10 +00:00
parent c612293a59
commit 94d19c51c6
12 changed files with 378 additions and 678 deletions

View File

@ -1,665 +0,0 @@
"""Benchmark offline inference throughput."""
import argparse
import base64
import json
import random
import time
from io import BytesIO
from typing import List, Optional, Tuple
import torch
import uvloop
from PIL import Image
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
)
from vllm import TokensPrompt
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
def sample_mm_requests_qwen2vl(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
result = []
for data in tqdm(json_data):
text = processor.apply_chat_template(data["chat_messages"], tokenize=False, add_generation_prompt=True)
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
_main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))
# Process inputs using processor
inputs = processor(
text=[text],
# images=[_main_image], # Don't pad out the image tokens yet, since that happens later inside of birr
padding=True,
return_tensors="np",
)
# print(inputs)
tokens = inputs["input_ids"][0]
prompt_len = len(tokens)
result.append(
(
TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=dict(image_embeds=torch.randn(1036, 3584), image_grid_thw=torch.tensor([[1, 74, 56]]))),
# multi_modal_data=dict(image=main_image)
)
),
prompt_len,
fixed_output_len,
)
)
if len(result) >= num_requests:
break
return result
def sample_mm_requests_phi3(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained("microsoft/Phi-3.5-vision-instruct", trust_remote_code=True)
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
result = []
for data in tqdm(json_data):
inputs = processor.tokenizer.apply_chat_template(
[{"role": "user", "content": "<|image_1|>\n" + data["chat_messages"][0]["content"][0]["text"]}], tokenize=True, add_generation_prompt=True
)
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))
# tokens = inputs["input_ids"][0]
tokens = inputs
prompt_len = len(tokens)
result.append(
(
TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=main_image),
)
),
prompt_len,
fixed_output_len,
)
)
if len(result) >= num_requests:
break
return result
def sample_mm_requests_molmo(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
):
processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-D-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")
with open(dataset_path, "r") as f:
json_data = [json.loads(line) for line in f.readlines() if len(line.strip()) > 0]
result = []
for data in tqdm(json_data):
raw_b64 = data["chat_messages"][0]["content"][1]["image_url"]["url"]
main_image = Image.open(BytesIO(base64.b64decode(raw_b64[raw_b64.find(",") + 1 :])))
inputs = inputs = processor.process(images=[main_image], text=data["chat_messages"][0]["content"][0]["text"])
# print(inputs)
# Molmo has max size of 4096 which is lower than our dataset was generated for
tokens = inputs["input_ids"][:2000]
# tokens = inputs
prompt_len = len(tokens)
result.append(
(
TokensPrompt(
dict(
prompt_token_ids=tokens,
multi_modal_data=dict(image=main_image),
)
),
prompt_len,
fixed_output_len,
)
)
if len(result) >= num_requests:
break
return result
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
# speculative_model="[ngram]",
# num_speculative_tokens=1,
# ngram_prompt_lookup_max=5,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
disable_async_output_proc=disable_async_output_proc,
)
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
)
)
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
llm.beam_search(
prompts,
BeamSearchParams(
beam_width=n,
max_tokens=output_len,
ignore_eos=True,
),
)
end = time.perf_counter()
return end - start
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
)
async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
)
)
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.perf_counter()
batch: List[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt, prompt_len, output_len = requests[i]
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.perf_counter()
return end - start
def run_mii(
requests: List[Tuple[str, int, int]],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm.generate(prompts, max_new_tokens=output_len)
end = time.perf_counter()
client = client(model)
client.terminate_server()
return end - start
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)]
else:
# requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
# args.output_len)
requests = sample_mm_requests_qwen2vl(args.dataset, args.num_prompts, tokenizer, args.output_len)
if args.backend == "vllm":
run_args = [
requests,
args.model,
args.tokenizer,
args.quantization,
args.tensor_parallel_size,
args.seed,
args.n,
args.trust_remote_code,
args.dtype,
args.max_model_len,
args.enforce_eager,
args.kv_cache_dtype,
args.quantization_param_path,
args.device,
args.enable_prefix_caching,
args.enable_chunked_prefill,
args.max_num_batched_tokens,
args.distributed_executor_backend,
args.gpu_memory_utilization,
args.num_scheduler_steps,
args.download_dir,
args.load_format,
args.disable_async_output_proc,
]
if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time = run_vllm(*run_args)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm")
parser.add_argument("--dataset", type=str, default=None, help="Path to the dataset.")
parser.add_argument("--input-len", type=int, default=None, help="Input prompt length for each request")
parser.add_argument("--output-len", type=int, default=None, help="Output length for each request. Overrides the " "output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--quantization", "-q", choices=[*QUANTIZATION_METHODS, None], default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.")
parser.add_argument("--trust-remote-code", action="store_true", help="trust remote code from huggingface")
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="Maximum length of a sequence (including prompt and output). " "If None, will be derived from the model.",
)
parser.add_argument(
"--dtype",
type=str,
default="auto",
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
help="data type for model weights and activations. "
'The "auto" option will use FP16 precision '
"for FP32 and FP16 models, and BF16 precision "
"for BF16 models.",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.9,
help="the fraction of GPU memory to be used for "
"the model executor, which can range from 0 to 1."
"If unspecified, will use the default value of 0.9.",
)
parser.add_argument("--enforce-eager", action="store_true", help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
)
parser.add_argument(
"--quantization-param-path",
type=str,
default=None,
help="Path to the JSON file containing the KV cache scaling factors. "
"This should generally be supplied, when KV cache dtype is FP8. "
"Otherwise, KV cache scaling factors default to 1.0, which may cause "
"accuracy issues. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is "
"instead supported for common inference criteria.",
)
parser.add_argument("--device", type=str, default="auto", choices=DEVICE_OPTIONS, help="device type for vLLM execution")
parser.add_argument("--num-scheduler-steps", type=int, default=1, help="Maximum number of forward steps per scheduler call.")
parser.add_argument("--enable-prefix-caching", action="store_true", help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill", action="store_true", help="enable chunked prefill for vLLM backend.")
parser.add_argument("--max-num-batched-tokens", type=int, default=None, help="maximum number of batched tokens per " "iteration")
parser.add_argument(
"--download-dir", type=str, default=None, help="directory to download and load the weights, " "default to the default cache dir of huggingface"
)
parser.add_argument("--output-json", type=str, default=None, help="Path to save the throughput results in JSON format.")
parser.add_argument(
"--distributed-executor-backend",
choices=["ray", "mp"],
default=None,
help="Backend to use for distributed serving. When more than 1 GPU "
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.',
)
parser.add_argument(
"--load-format",
type=str,
default=EngineArgs.load_format,
choices=["auto", "pt", "safetensors", "npcache", "dummy", "tensorizer", "bitsandbytes"],
help="The format of the model weights to load.\n\n"
'* "auto" will try to load the weights in the safetensors format '
"and fall back to the pytorch bin format if safetensors format "
"is not available.\n"
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
"a numpy cache to speed up the loading.\n"
'* "dummy" will initialize the weights with random values, '
"which is mainly for profiling.\n"
'* "tensorizer" will load the weights using tensorizer from '
"CoreWeave. See the Tensorize vLLM Model script in the Examples"
"section for more information.\n"
'* "bitsandbytes" will load the weights using bitsandbytes '
"quantization.\n",
)
parser.add_argument("--disable-async-output-proc", action="store_true", default=False, help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine", action="store_true", default=False, help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing", action="store_true", default=False, help="Disable decoupled async engine frontend.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII " "backend.")
main(args)

346
scripts/clean_olmocrmix.py Executable file
View File

@ -0,0 +1,346 @@
#!/usr/bin/env python3
# Takes a dataset location in olmocr-mix format, (ex. a nested directory structure folder/subfolder/document.md with a corresponding folder/subfolder/document.pdf)
# Then, it will randomly shuffle these (with a fixed seed), and prompt chatgpt to clean up the transcription, and output a cleaned document
# Uses structured output to get a good result, then writes things back in the same format in a new root folder, preserving the original folder structure
import argparse
import json
import os
import random
import sys
from pathlib import Path
from typing import List, Tuple, Any, Dict
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor, as_completed
from pypdf import PdfReader
from olmocr.data.renderpdf import render_pdf_to_base64png
from openai import OpenAI
from pydantic import BaseModel, Field
from tqdm import tqdm
# Structured output model for ChatGPT response
class CleanedDocument(BaseModel):
cleaned_text: str = Field(description="The cleaned and corrected version of the OCR transcription")
confidence_score: float = Field(description="Confidence score from 0 to 1 indicating how confident the model is in the cleaning", ge=0.0, le=1.0)
corrections_made: List[str] = Field(description="List of major corrections or improvements made to the text")
@dataclass
class DocumentPair:
md_path: Path
pdf_path: Path
relative_path: Path # Relative path from root for preserving structure
def parse_args():
parser = argparse.ArgumentParser(
description="Clean OCR transcriptions using ChatGPT with visual PDF context"
)
parser.add_argument(
"input_dir",
help="Input directory containing olmocr-mix format data (MD files with corresponding PDFs)"
)
parser.add_argument(
"output_dir",
help="Output directory for cleaned documents (preserves folder structure)"
)
parser.add_argument(
"--openai-api-key",
help="OpenAI API key (can also be set via OPENAI_API_KEY environment variable)",
default=os.getenv("OPENAI_API_KEY")
)
parser.add_argument(
"--model",
default="gpt-4o-mini",
help="OpenAI model to use (default: gpt-4o-mini)"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for shuffling documents (default: 42)"
)
parser.add_argument(
"--batch-size",
type=int,
default=10,
help="Number of documents to process in parallel (default: 10)"
)
parser.add_argument(
"--max-documents",
type=int,
help="Maximum number of documents to process (useful for testing)"
)
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip documents that already have cleaned versions in the output directory"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose output"
)
return parser.parse_args()
def check_single_page_pdf(pdf_path: Path) -> bool:
"""Check if a PDF has exactly one page."""
try:
with open(pdf_path, 'rb') as pdf_file:
pdf_reader = PdfReader(pdf_file)
return len(pdf_reader.pages) == 1
except Exception as e:
print(f"Error checking PDF {pdf_path}: {e}")
return False
def find_document_pairs(input_dir: Path, verbose: bool = False) -> List[DocumentPair]:
"""Find all MD files with corresponding single-page PDF files."""
pairs = []
skipped_no_pdf = 0
skipped_multi_page = 0
for md_path in input_dir.rglob("*.md"):
# Check for corresponding PDF
pdf_path = md_path.with_suffix(".pdf")
if not pdf_path.exists():
if verbose:
print(f"Warning: No PDF found for {md_path}")
skipped_no_pdf += 1
continue
# Check if PDF has exactly one page
if not check_single_page_pdf(pdf_path):
if verbose:
print(f"Warning: Skipping multi-page PDF {pdf_path}")
skipped_multi_page += 1
continue
relative_path = md_path.relative_to(input_dir)
pairs.append(DocumentPair(md_path, pdf_path, relative_path))
if skipped_no_pdf > 0 or skipped_multi_page > 0:
print(f"Skipped {skipped_no_pdf} files without PDFs and {skipped_multi_page} multi-page PDFs")
return pairs
def render_single_page_pdf(pdf_path: Path) -> str:
"""Render a single-page PDF to base64 PNG image."""
try:
# Use render_pdf_to_base64png with target_longest_image_dim=2048
base64_png = render_pdf_to_base64png(
str(pdf_path),
1, # Always page 1 since we validated it's a single-page PDF
target_longest_image_dim=2048
)
return base64_png
except Exception as e:
raise RuntimeError(f"Could not render PDF {pdf_path}: {e}")
def clean_document_with_chatgpt(
client: OpenAI,
model: str,
md_content: str,
pdf_image: str,
verbose: bool = False
) -> CleanedDocument:
"""Use ChatGPT to clean the OCR transcription with PDF context."""
# Prepare the messages
messages: List[Dict[str, Any]] = [
{
"role": "system",
"content": (
"You are an expert at cleaning and correcting OCR transcriptions. "
"You will be given an OCR transcription and an image of the original PDF page. "
"Your task is to:\n"
"1. Fix OCR errors and typos\n"
"2. Correct formatting issues\n"
"3. Restore proper punctuation and capitalization\n"
"4. Fix word breaks and line breaks\n"
"5. Ensure mathematical formulas and special characters are correct\n"
"6. Maintain the semantic structure of the document\n"
"Return a cleaned version that accurately represents the original document."
)
}
]
# Add the content with the PDF image
content: List[Dict[str, Any]] = [
{
"type": "text",
"text": f"Please clean the following OCR transcription based on the provided PDF page image:\n\n{md_content}"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{pdf_image}"
}
}
]
messages.append({
"role": "user",
"content": content
})
# Make the API call with structured output
try:
response = client.beta.chat.completions.parse(
model=model,
messages=messages, # type: ignore
response_format=CleanedDocument,
temperature=0.2, # Lower temperature for more consistent cleaning
max_tokens=16384
)
parsed_result = response.choices[0].message.parsed
if parsed_result is None:
raise ValueError("ChatGPT returned no parsed result")
return parsed_result
except Exception as e:
print(f"Error calling ChatGPT: {e}")
raise
def process_document(
doc_pair: DocumentPair,
client: OpenAI,
model: str,
output_dir: Path,
skip_existing: bool,
verbose: bool
) -> Tuple[bool, str]:
"""Process a single document pair."""
# Check if output already exists
output_path = output_dir / doc_pair.relative_path
if skip_existing and output_path.exists():
return True, f"Skipped (already exists): {doc_pair.relative_path}"
try:
# Read the markdown content
md_content = doc_pair.md_path.read_text(encoding='utf-8')
# Render the single PDF page
pdf_image = render_single_page_pdf(doc_pair.pdf_path)
# Clean with ChatGPT
cleaned_result = clean_document_with_chatgpt(
client, model, md_content, pdf_image, verbose
)
# Create output directory if needed
output_path.parent.mkdir(parents=True, exist_ok=True)
# Write cleaned text
output_path.write_text(cleaned_result.cleaned_text, encoding='utf-8')
# Also write metadata
metadata_path = output_path.with_suffix('.json')
metadata = {
'original_md': str(doc_pair.md_path),
'original_pdf': str(doc_pair.pdf_path),
'confidence_score': cleaned_result.confidence_score,
'corrections_made': cleaned_result.corrections_made,
'model': model,
'pages_rendered': 1
}
metadata_path.write_text(json.dumps(metadata, indent=2), encoding='utf-8')
return True, f"Processed: {doc_pair.relative_path} (confidence: {cleaned_result.confidence_score:.2f})"
except Exception as e:
return False, f"Error processing {doc_pair.relative_path}: {e}"
def main():
args = parse_args()
# Validate API key
if not args.openai_api_key:
print("Error: OpenAI API key is required. Set via --openai-api-key or OPENAI_API_KEY environment variable.")
sys.exit(1)
# Initialize OpenAI client
client = OpenAI(api_key=args.openai_api_key)
# Set up paths
input_dir = Path(args.input_dir)
output_dir = Path(args.output_dir)
if not input_dir.exists():
print(f"Error: Input directory {input_dir} does not exist.")
sys.exit(1)
output_dir.mkdir(parents=True, exist_ok=True)
# Find all document pairs (single-page PDFs only)
print(f"Scanning {input_dir} for single-page document pairs...")
doc_pairs = find_document_pairs(input_dir, args.verbose)
print(f"Found {len(doc_pairs)} valid single-page document pairs.")
if not doc_pairs:
print("No document pairs found.")
return
# Shuffle with fixed seed
random.seed(args.seed)
random.shuffle(doc_pairs)
# Limit if requested
if args.max_documents:
doc_pairs = doc_pairs[:args.max_documents]
print(f"Processing first {args.max_documents} documents after shuffling.")
# Process documents in batches
successful = 0
failed = 0
with ThreadPoolExecutor(max_workers=args.batch_size) as executor:
futures = []
for doc_pair in doc_pairs:
future = executor.submit(
process_document,
doc_pair,
client,
args.model,
output_dir,
args.skip_existing,
args.verbose
)
futures.append(future)
# Process results with progress bar
with tqdm(total=len(futures), desc="Processing documents") as pbar:
for future in as_completed(futures):
success, message = future.result()
if success:
successful += 1
else:
failed += 1
if args.verbose:
tqdm.write(message)
pbar.update(1)
pbar.set_postfix({
'successful': successful,
'failed': failed
})
# Print summary
print(f"\nProcessing complete:")
print(f" Successful: {successful}")
print(f" Failed: {failed}")
print(f" Output directory: {output_dir}")
if __name__ == "__main__":
main()

View File

@ -6,6 +6,7 @@ set -e
SKIP_DOCKER_BUILD=false
PREEMPTIBLE=false
EXP_NAME=""
NUM_GPUS=4
# Store all arguments to pass to python command
PYTHON_ARGS=()
@ -24,6 +25,14 @@ while [[ $# -gt 0 ]]; do
EXP_NAME="$2"
shift 2
;;
--num-gpus)
NUM_GPUS="$2"
if [ "$NUM_GPUS" -lt 2 ] || [ "$NUM_GPUS" -gt 8 ]; then
echo "Error: --num-gpus must be between 2 and 8 (got: $NUM_GPUS)"
exit 1
fi
shift 2
;;
--help|-h)
echo "Usage: $0 [beaker-options] [grpo-training-options]"
echo ""
@ -31,13 +40,14 @@ while [[ $# -gt 0 ]]; do
echo " --skip-docker-build Skip Docker build"
echo " --preemptible Use preemptible instances"
echo " --name NAME Experiment name (used in output directory)"
echo " --num-gpus N Number of GPUs to use (2-8, default: 4)"
echo ""
echo "All other arguments are forwarded to python -m olmocr.train.grpo_train"
echo "Run 'python -m olmocr.train.grpo_train --help' to see available training options"
echo ""
echo "This multi-GPU version runs:"
echo " - VLLM server on GPU 3"
echo " - Training on GPUs 0,1,2 with DeepSpeed"
echo " - VLLM server on the last GPU"
echo " - Training on all other GPUs with DeepSpeed"
exit 0
;;
*)
@ -50,6 +60,7 @@ done
echo "Preemptible: $PREEMPTIBLE"
echo "Skip Docker Build: $SKIP_DOCKER_BUILD"
echo "Number of GPUs: $NUM_GPUS"
echo "Arguments to forward: ${PYTHON_ARGS[@]}"
# Use conda environment Python if available, otherwise use system Python
@ -109,8 +120,15 @@ git_branch = sys.argv[3]
git_hash = sys.argv[4]
preemptible = sys.argv[5] == "true"
exp_name = sys.argv[6] # Empty string if not provided
num_gpus = int(sys.argv[7])
# All remaining arguments are the python command arguments
python_args = sys.argv[7:]
python_args = sys.argv[8:]
# Calculate GPU assignments
vllm_gpu = num_gpus - 1 # Last GPU for VLLM
training_gpus = list(range(num_gpus - 1)) # All other GPUs for training
training_gpu_str = ",".join(str(g) for g in training_gpus)
num_training_processes = len(training_gpus)
# Initialize Beaker client
b = Beaker.from_env(default_workspace="ai2/olmocr")
@ -176,7 +194,7 @@ else:
# Build the GRPO training command with forwarded arguments
# Force --vllm_mode server
grpo_cmd = "CUDA_VISIBLE_DEVICES=0,1,2 accelerate launch --use_deepspeed --zero_stage 2 --num_processes 3 --gradient_accumulation_steps 8 -m olmocr.train.grpo_train"
grpo_cmd = f"CUDA_VISIBLE_DEVICES={training_gpu_str} accelerate launch --use_deepspeed --zero_stage 2 --num_processes {num_training_processes} --gradient_accumulation_steps 8 -m olmocr.train.grpo_train"
# Add --vllm_mode server if not already in arguments
arg_str = " ".join(modified_args)
@ -214,22 +232,22 @@ for i, arg in enumerate(modified_args):
grpo_cmd += " " + " ".join(filtered_args)
# Create a bash script as a single command string
bash_script = """
bash_script = f"""
set -e
# Setup commands
""" + " && ".join(setup_commands) + """
{" && ".join(setup_commands)}
# Start VLLM server in background
echo 'Starting VLLM server on GPU 3 as background process...'
CUDA_VISIBLE_DEVICES=3 nohup trl vllm-serve --model """ + vllm_model_arg + """ --port 8000 --gpu-memory-utilization 0.9 > /tmp/vllm_server.log 2>&1 &
echo 'Starting VLLM server on GPU {vllm_gpu} as background process...'
CUDA_VISIBLE_DEVICES={vllm_gpu} nohup trl vllm-serve --model {vllm_model_arg} --port 8000 --gpu-memory-utilization 0.9 > /tmp/vllm_server.log 2>&1 &
VLLM_PID=$!
echo "VLLM server started with PID: $VLLM_PID"
# Wait for VLLM server to be ready
echo 'Waiting for VLLM server to be ready...'
sleep 30
for i in {1..60}; do
for i in {{1..60}}; do
if curl -s http://localhost:8000/health; then
echo ' - VLLM server is ready!'
break
@ -240,8 +258,8 @@ for i in {1..60}; do
done
# Run training
echo 'Starting GRPO training on GPUs 0,1,2...'
""" + grpo_cmd + """
echo 'Starting GRPO training on GPUs {training_gpu_str}...'
{grpo_cmd}
# Cleanup
echo 'Training completed. Killing VLLM server...'
@ -262,7 +280,7 @@ task_spec = TaskSpec(
preemptible=preemptible,
),
resources=TaskResources(
gpu_count=4, # Request 4 GPUs total
gpu_count=num_gpus, # Request the specified number of GPUs
shared_memory="10GiB"
),
constraints=Constraints(cluster=["ai2/jupiter", "ai2/saturn"]),
@ -291,7 +309,7 @@ for i, arg in enumerate(modified_args):
# Create experiment spec with single task
experiment_spec = ExperimentSpec(
description=f"OlmOCR GRPO Multi-GPU Training (3 GPUs + VLLM Server) - Model: {model_name}, Branch: {git_branch}, Commit: {git_hash}",
description=f"OlmOCR GRPO Multi-GPU Training ({num_training_processes} GPUs + VLLM Server) - Model: {model_name}, Branch: {git_branch}, Commit: {git_hash}",
budget="ai2/oe-base",
tasks=[task_spec], # Single task that manages both VLLM and training
)
@ -311,6 +329,7 @@ $PYTHON /tmp/run_grpo_experiment_multi_gpu.py \
"$GIT_HASH" \
"$PREEMPTIBLE" \
"$EXP_NAME" \
"$NUM_GPUS" \
"${PYTHON_ARGS[@]}"
# Clean up temporary file