fixed style

aman-17 2025-02-25 08:57:02 -08:00
parent c2b54d8525
commit 0130a970c2
22 changed files with 113 additions and 162 deletions

View File

@@ -12,14 +12,15 @@ The final score is averaged over the repeated generations.
"""
import argparse
import os
import json
import glob
import sys
import itertools
import json
import os
import sys
from rapidfuzz import fuzz
from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
"""
@@ -30,7 +31,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
rules = []
rule_ids = set()
with open(jsonl_path, 'r', encoding='utf-8') as f:
with open(jsonl_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1):
line = line.strip()
if not line:
@@ -75,6 +76,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
rules.append(data)
return rules
def run_rule(rule, md_file_path: str) -> (bool, str):
"""
Run the given rule on the content of the provided .md file.
@@ -82,7 +84,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
and 'explanation' is a short message explaining the failure when the rule does not pass.
"""
try:
with open(md_file_path, 'r', encoding='utf-8') as f:
with open(md_file_path, "r", encoding="utf-8") as f:
md_content = f.read()
except Exception as e:
return (False, f"Error reading {md_file_path}: {e}")
@@ -121,6 +123,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
else:
raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.")
def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
"""
For the candidate folder (pipeline tool output), validate that it contains at least one .md file
@@ -148,9 +151,7 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li
md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md")
md_files = glob.glob(md_pattern)
if not md_files:
candidate_errors.append(
f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md)."
)
candidate_errors.append(f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md).")
else:
pdf_to_md_files[pdf_name] = md_files
@@ -195,11 +196,14 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li
overall_score = total_rule_score / len(all_rules) if all_rules else 0.0
return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
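Per the module docstring, each rule's score is the fraction of repeated generations that pass it, and a candidate's overall score is the mean over all rules (total_rule_score / len(all_rules) above). A compact sketch of that arithmetic, with illustrative names:

def rule_score(passes: list[bool]) -> float:
    # Fraction of repeated generations on which this rule passed.
    return sum(passes) / len(passes)

def candidate_score(per_rule_passes: list[list[bool]]) -> float:
    # Mean of per-rule scores, i.e. total_rule_score / len(all_rules).
    return sum(rule_score(p) for p in per_rule_passes) / len(per_rule_passes)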
def main():
parser = argparse.ArgumentParser(description="Run OLMOCR Bench.")
parser.add_argument("--input_folder",
parser.add_argument(
"--input_folder",
default=os.path.join(os.path.dirname(__file__), "sample_data"),
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.")
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.",
)
args = parser.parse_args()
input_folder = args.input_folder
@@ -268,7 +272,7 @@ def main():
print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.")
# Print final summary with breakdown by rule type
print("\n" + "="*50)
print("\n" + "=" * 50)
print("Final Summary:")
for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary:
if candidate_errors:
@@ -283,7 +287,8 @@ def main():
else:
avg = 0.0
print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules")
print("="*50)
print("=" * 50)
if __name__ == "__main__":
main()
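The rule checks elided above lean on the two fuzzy-matching imports at the top of this file. A minimal sketch of how a "present"/"absent" text rule can be evaluated with rapidfuzz and fuzzysearch; the rule schema and threshold default here are assumptions, not the exact bench implementation:

from fuzzysearch import find_near_matches
from rapidfuzz import fuzz

def check_text_rule(rule: dict, md_content: str) -> tuple[bool, str]:
    # Hypothetical rule shape: {"type": "present", "text": "...", "threshold": 90}
    target = rule["text"]
    threshold = rule.get("threshold", 90)
    score = fuzz.partial_ratio(target, md_content)  # best-substring score, 0-100
    if rule["type"] == "present":
        return (score >= threshold, f"best fuzzy match scored {score:.1f} < {threshold}")
    # "absent": fail if the text appears within a small edit distance.
    matches = find_near_matches(target, md_content, max_l_dist=2)
    return (not matches, f"found near match for {target!r}")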

View File

@@ -1,9 +1,11 @@
import argparse
import os
import glob
import importlib
import os
from tqdm import tqdm
def parse_method_arg(method_arg):
"""
Parse a method configuration string of the form:
@@ -29,22 +31,11 @@ def parse_method_arg(method_arg):
raise ValueError(f"Extra argument '{extra}' is not in key=value format")
return name, kwargs
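The elided body of parse_method_arg can be sketched as follows; this is an illustration consistent with the docstring and the two lines shown above, not the verbatim implementation:

def parse_method_arg(method_arg):
    # "mineru:temperature=2:runs=3" -> ("mineru", {"temperature": "2", "runs": "3"})
    name, *extras = method_arg.split(":")
    kwargs = {}
    for extra in extras:
        if "=" not in extra:
            raise ValueError(f"Extra argument '{extra}' is not in key=value format")
        key, value = extra.split("=", 1)
        kwargs[key] = value  # a real implementation may coerce numeric values
    return name, kwargs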
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run PDF conversion using specified OCR methods and extra parameters."
)
parser.add_argument(
"methods",
nargs="+",
help="Methods to run in the format method[:key=value ...]. "
"Example: gotocr mineru:temperature=2 marker:runs=3"
)
parser.add_argument(
"--repeats",
type=int,
default=1,
help="Number of times to repeat the conversion for each PDF."
)
parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
args = parser.parse_args()
# Mapping of method names to a tuple: (module path, function name)
@@ -60,16 +51,12 @@ if __name__ == "__main__":
for method_arg in args.methods:
method_name, extra_kwargs = parse_method_arg(method_arg)
if method_name not in available_methods:
parser.error(f"Unknown method: {method_name}. "
f"Available methods: {', '.join(available_methods.keys())}")
parser.error(f"Unknown method: {method_name}. " f"Available methods: {', '.join(available_methods.keys())}")
module_path, function_name = available_methods[method_name]
# Dynamically import the module and get the function.
module = importlib.import_module(module_path)
function = getattr(module, function_name)
config[method_name] = {
"method": function,
"kwargs": extra_kwargs
}
config[method_name] = {"method": function, "kwargs": extra_kwargs}
data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
pdf_directory = os.path.join(data_directory, "pdfs")

View File

@@ -1,16 +1,14 @@
# type: ignore
import os
import tempfile
import base64
import torch
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.data.anchor import get_anchor_text
from olmocr.data.prompts import build_openai_silver_data_prompt
from openai import OpenAI
from olmocr.data.anchor import get_anchor_text
from olmocr.data.prompts import build_openai_silver_data_prompt
from olmocr.data.renderpdf import render_pdf_to_base64png
def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06') -> str:
def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06") -> str:
"""
Convert page of a PDF file to markdown using ChatGPT.
@@ -24,19 +22,19 @@ def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06')
str: The OCR result in markdown format.
"""
# Convert the first page of the PDF to a base64-encoded PNG image.
base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
# base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.chat.completions.create(
model=model,
messages= [
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
],
@@ -44,10 +42,8 @@ def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06')
max_tokens=3000,
logprobs=True,
top_logprobs=5,
response_format=openai_response_format_schema()
response_format=openai_response_format_schema(),
)
print(response)
return result

View File

@@ -1,16 +1,18 @@
import base64
import os
import tempfile
import base64
import torch
from transformers import AutoModel, AutoTokenizer
from olmocr.data.renderpdf import render_pdf_to_base64png
from transformers import AutoModel, AutoTokenizer
# Global cache for the model and tokenizer.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_model = None
_tokenizer = None
def load_model():
"""
Load the GOT-OCR model and tokenizer if they haven't been loaded already.
@@ -20,21 +22,19 @@ def load_model():
"""
global _model, _tokenizer
if _model is None or _tokenizer is None:
_tokenizer = AutoTokenizer.from_pretrained(
'ucaslcl/GOT-OCR2_0', trust_remote_code=True
)
_tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
_model = AutoModel.from_pretrained(
'ucaslcl/GOT-OCR2_0',
"ucaslcl/GOT-OCR2_0",
trust_remote_code=True,
use_safetensors=True,
revision="979938bf89ccdc949c0131ddd3841e24578a4742",
pad_token_id=_tokenizer.eos_token_id
pad_token_id=_tokenizer.eos_token_id,
)
_model = _model.eval().to(_device)
return _model, _tokenizer
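The elided middle of run_gotocr renders the page to a temporary PNG and hands it to the cached model. A hedged sketch of that flow (model.chat is GOT-OCR2_0's documented entry point; the temp-file handling here is an assumption):

import base64
import os
import tempfile

from olmocr.data.renderpdf import render_pdf_to_base64png

def ocr_page(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
    model, tokenizer = load_model()
    png_bytes = base64.b64decode(render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024))
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp.write(png_bytes)
        tmp_filename = tmp.name
    # GOT-OCR2_0 exposes OCR as model.chat(tokenizer, image_path, ocr_type=...).
    result = model.chat(tokenizer, tmp_filename, ocr_type=ocr_type)
    os.remove(tmp_filename)
    return result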
def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str:
def run_gotocr(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
"""
Convert page of a PDF file to markdown using GOT-OCR.
@@ -65,5 +65,3 @@ def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str:
os.remove(tmp_filename)
return result

View File

@@ -1,15 +1,11 @@
import os
import time
import argparse
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
_marker_converter = None
def run_marker(pdf_path: str, page_num: int=1) -> str:
def run_marker(pdf_path: str, page_num: int = 1) -> str:
global _marker_converter
if _marker_converter is None:
@@ -22,4 +18,3 @@ def run_marker(pdf_path: str, page_num: int=1) -> str:
text, _, images = text_from_rendered(rendered)
return text

View File

@@ -1,15 +1,13 @@
import os
import shutil
import argparse
import tempfile
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
def run_mineru(pdf_path: str, page_num: int=1) -> str:
def run_mineru(pdf_path: str, page_num: int = 1) -> str:
output_folder = tempfile.TemporaryDirectory()
image_output_folder = tempfile.TemporaryDirectory()
@@ -34,7 +32,7 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
# Generate markdown content; the image directory is the basename of the images output folder
image_dir_basename = os.path.basename(image_output_folder.name)
md_content = pipe_result.get_markdown(image_dir_basename)
# md_content = pipe_result.get_markdown(image_dir_basename)
# Dump markdown file
with tempfile.NamedTemporaryFile("w+", suffix="md") as tf:
@@ -45,4 +43,3 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
md_data = tf.read()
return md_data

View File

@@ -1,9 +1,10 @@
import sys
import asyncio
import glob
import json
import os
import shutil
import asyncio
import sys
import olmocr.pipeline
# Set sys.argv as if you were running the script from the command line.
@@ -13,7 +14,8 @@ workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
sys.argv = [
"pipeline.py", # The script name (can be arbitrary)
"olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace
"--pdfs", *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
"--pdfs",
*list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
]
# Call the async main() function.
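The elided call that follows amounts to driving the pipeline's async entry point, along these lines (assuming olmocr.pipeline exposes an async main(), as the comment above suggests):

import asyncio

# Run the pipeline exactly as the CLI would, using the sys.argv set above.
asyncio.run(olmocr.pipeline.main())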

View File

@@ -100,7 +100,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pd
try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1):
line = line.strip()
if not line:

View File

@@ -34,7 +34,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1):
line = line.strip()
if not line:

View File

@@ -168,5 +168,4 @@ class BenchmarkRepeatDetect(unittest.TestCase):
if __name__ == "__main__":
unittest.main()

View File

@@ -1 +1,2 @@
class DolmaRefineError(RuntimeError): ...
class DolmaRefineError(RuntimeError):
...

View File

@@ -5,10 +5,7 @@ import re
from typing import Optional
import boto3
from datasets import (
Dataset,
load_dataset,
)
from datasets import Dataset, load_dataset
from filelock import FileLock
from olmocr.data.renderpdf import get_pdf_media_box_width_height

View File

@@ -4,17 +4,11 @@ from io import BytesIO
import torch
import torch.distributed
from PIL import Image
from transformers import (
AutoConfig,
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
)
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
build_openai_silver_data_prompt,
)
from olmocr.prompts.prompts import build_openai_silver_data_prompt
@torch.no_grad()

View File

@@ -9,11 +9,7 @@ import torchvision.transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
)
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
from transformers.processing_utils import ImagesKwargs
from transformers.utils import logging

View File

@@ -740,7 +740,6 @@ class ViTMLP(nn.Module):
class ResidualAttentionBlock(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
@@ -772,7 +771,6 @@ class ResidualAttentionBlock(nn.Module):
class BlockCollection(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
@@ -801,7 +799,6 @@ class LayerNormFp32(nn.LayerNorm):
class VisionTransformer(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
@@ -952,7 +949,6 @@ class MultiHeadDotProductAttention(nn.Module):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
inputs_v = inputs_kv
@@ -1099,7 +1095,6 @@ class MultiHeadAttentionPool(nn.Module):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
if self.query == "mean":

View File

@@ -16,11 +16,7 @@ import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
)
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

View File

@@ -147,7 +147,6 @@ def run_train(config: TrainConfig):
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
with TemporaryDirectory() as output_dir:
training_args = TrainingArguments(
run_name=run_name.run,
logging_steps=config.hparams.log_every_steps,

View File

@@ -224,7 +224,6 @@ class TruncatingCollator:
@contextmanager
def get_local_dir(output_dir: str):
with TemporaryDirectory() as tmp_dir:
if is_local(output_dir):
yield output_dir

View File

@@ -362,7 +362,6 @@ async def run_vllm_async(
)
async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []

View File

@@ -1,34 +1,37 @@
#!/usr/bin/env python3
import argparse
import boto3
import json
import random
import re
import requests
import time
import boto3
import requests
from tqdm import tqdm
from transformers import AutoTokenizer
# Allowed characters: alphanumeric, space, and basic punctuation ".,!?()"
ALLOWED_RE = re.compile(r'^[A-Za-z0-9\.,!?() ]+$')
ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
def get_random_line_from_s3(bucket, key):
"""
Reads an S3 object line-by-line and returns a random line using reservoir sampling.
"""
s3 = boto3.client('s3')
s3 = boto3.client("s3")
response = s3.get_object(Bucket=bucket, Key=key)
random_line = None
count = 0
for line in response['Body'].iter_lines():
for line in response["Body"].iter_lines():
if not line:
continue
line_str = line.decode('utf-8')
line_str = line.decode("utf-8")
count += 1
if random.randint(1, count) == 1:
random_line = line_str
return random_line
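Reservoir sampling works because line i replaces the current choice with probability 1/i, so after N lines each line survives with probability (1/i) * i/(i+1) * ... * (N-1)/N = 1/N, uniform without knowing N up front. The same trick in isolation, as a generic sketch independent of S3:

import random

def random_line(lines):
    chosen = None
    for count, line in enumerate(lines, start=1):
        # Keep the new line with probability 1/count; uniform 1/N overall.
        if random.randint(1, count) == 1:
            chosen = line
    return chosen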
def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
"""
Sends a count query to the infini-gram API for the given n-gram.
@@ -47,10 +50,11 @@ def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
result = response.json()
if "count" in result:
return result["count"]
except Exception as e:
except Exception as e: # type: ignore
time.sleep(1)
return 0
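The elided request at the top of query_infinigram is a JSON POST; a sketch of the call it retries (the endpoint and payload shape follow the public infini-gram API documentation and should be treated as assumptions here):

import requests

def count_ngram(ngram: str, index: str = "v4_rpj_llama_s4") -> int:
    payload = {"index": index, "query_type": "count", "query": ngram}
    response = requests.post("https://api.infini-gram.io/", json=payload, timeout=10)
    result = response.json()
    return result.get("count", 0)  # 0 when the API returns no count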
def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
"""
Tokenizes the document using the Llama2 tokenizer and samples random n-grams.
@@ -67,7 +71,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
# Get tokenized representation with offset mapping to determine word boundaries.
tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
token_ids = tokenized["input_ids"]
offsets = tokenized["offset_mapping"]
# offsets = tokenized["offset_mapping"]
if len(token_ids) < ngram_size:
return doc_id, 0, 0, []
@@ -88,7 +92,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
max_attempts = num_samples * 10 # Limit to prevent infinite loops.
while len(valid_ngram_details) < num_samples and attempts < max_attempts:
idx = random.choice(valid_positions)
ngram_token_ids = token_ids[idx: idx+ngram_size]
ngram_token_ids = token_ids[idx : idx + ngram_size]
ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
# Only accept n-grams that contain only allowed characters.
if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
@@ -101,10 +105,9 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
sample_count = len(valid_ngram_details)
return doc_id, match_count, sample_count, valid_ngram_details
def main():
parser = argparse.ArgumentParser(
description="Infini-gram n-gram matching script with Llama2 tokenization."
)
parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
@@ -149,9 +152,7 @@ def main():
except Exception as e:
print(f"Error parsing JSON in {key}: {e}")
continue
doc_id, match_count, sample_count, details = process_document(
doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index
)
doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)
# Print per-document n-gram summary
print(f"\nDocument ID: {doc_id}")
@@ -167,5 +168,6 @@ def main():
overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")
if __name__ == "__main__":
main()

View File

@@ -16,9 +16,7 @@ from tqdm import tqdm
from transformers import AutoProcessor
from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import (
build_finetuning_dataset,
)
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training,
build_finetuning_prompt,
@@ -223,7 +221,6 @@ class TestMolmoDataPrep(unittest.TestCase):
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
):
# Set return values for the mocked functions
mock_get_anchor_text.return_value = "This is the anchor text."
# Create a red square image and encode it in base64
@@ -305,7 +302,6 @@ class TestMolmoDataPrep(unittest.TestCase):
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
):
mock_get_anchor_text.return_value = "This is the anchor text."
img = Image.new("RGB", (100, 100), color="red")
buffered = BytesIO()