fixed style

This commit is contained in:
aman-17 2025-02-25 08:57:02 -08:00
parent c2b54d8525
commit 0130a970c2
22 changed files with 113 additions and 162 deletions

View File

@ -12,14 +12,15 @@ The final score is averaged over the repeated generations.
""" """
import argparse import argparse
import os
import json
import glob import glob
import sys
import itertools import itertools
import json
import os
import sys
from rapidfuzz import fuzz
from fuzzysearch import find_near_matches from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]): def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
""" """
@ -30,7 +31,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
rules = [] rules = []
rule_ids = set() rule_ids = set()
with open(jsonl_path, 'r', encoding='utf-8') as f: with open(jsonl_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1): for line_num, line in enumerate(f, start=1):
line = line.strip() line = line.strip()
if not line: if not line:
@ -75,6 +76,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
rules.append(data) rules.append(data)
return rules return rules
def run_rule(rule, md_file_path: str) -> (bool, str): def run_rule(rule, md_file_path: str) -> (bool, str):
""" """
Run the given rule on the content of the provided .md file. Run the given rule on the content of the provided .md file.
@ -82,7 +84,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
and 'explanation' is a short message explaining the failure when the rule does not pass. and 'explanation' is a short message explaining the failure when the rule does not pass.
""" """
try: try:
with open(md_file_path, 'r', encoding='utf-8') as f: with open(md_file_path, "r", encoding="utf-8") as f:
md_content = f.read() md_content = f.read()
except Exception as e: except Exception as e:
return (False, f"Error reading {md_file_path}: {e}") return (False, f"Error reading {md_file_path}: {e}")
@ -121,6 +123,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
else: else:
raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.") raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.")
def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]): def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
""" """
For the candidate folder (pipeline tool output), validate that it contains at least one .md file For the candidate folder (pipeline tool output), validate that it contains at least one .md file
@ -148,9 +151,7 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li
md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md") md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md")
md_files = glob.glob(md_pattern) md_files = glob.glob(md_pattern)
if not md_files: if not md_files:
candidate_errors.append( candidate_errors.append(f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md).")
f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md)."
)
else: else:
pdf_to_md_files[pdf_name] = md_files pdf_to_md_files[pdf_name] = md_files
@ -195,11 +196,14 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li
overall_score = total_rule_score / len(all_rules) if all_rules else 0.0 overall_score = total_rule_score / len(all_rules) if all_rules else 0.0
return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown) return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
def main(): def main():
parser = argparse.ArgumentParser(description="Run OLMOCR Bench.") parser = argparse.ArgumentParser(description="Run OLMOCR Bench.")
parser.add_argument("--input_folder", parser.add_argument(
default=os.path.join(os.path.dirname(__file__), "sample_data"), "--input_folder",
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.") default=os.path.join(os.path.dirname(__file__), "sample_data"),
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.",
)
args = parser.parse_args() args = parser.parse_args()
input_folder = args.input_folder input_folder = args.input_folder
@ -268,7 +272,7 @@ def main():
print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.") print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.")
# Print final summary with breakdown by rule type # Print final summary with breakdown by rule type
print("\n" + "="*50) print("\n" + "=" * 50)
print("Final Summary:") print("Final Summary:")
for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary: for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary:
if candidate_errors: if candidate_errors:
@ -283,7 +287,8 @@ def main():
else: else:
avg = 0.0 avg = 0.0
print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules") print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules")
print("="*50) print("=" * 50)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,9 +1,11 @@
import argparse import argparse
import os
import glob import glob
import importlib import importlib
import os
from tqdm import tqdm from tqdm import tqdm
def parse_method_arg(method_arg): def parse_method_arg(method_arg):
""" """
Parse a method configuration string of the form: Parse a method configuration string of the form:
@ -29,22 +31,11 @@ def parse_method_arg(method_arg):
raise ValueError(f"Extra argument '{extra}' is not in key=value format") raise ValueError(f"Extra argument '{extra}' is not in key=value format")
return name, kwargs return name, kwargs
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
description="Run PDF conversion using specified OCR methods and extra parameters." parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
) parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
parser.add_argument(
"methods",
nargs="+",
help="Methods to run in the format method[:key=value ...]. "
"Example: gotocr mineru:temperature=2 marker:runs=3"
)
parser.add_argument(
"--repeats",
type=int,
default=1,
help="Number of times to repeat the conversion for each PDF."
)
args = parser.parse_args() args = parser.parse_args()
# Mapping of method names to a tuple: (module path, function name) # Mapping of method names to a tuple: (module path, function name)
@ -60,16 +51,12 @@ if __name__ == "__main__":
for method_arg in args.methods: for method_arg in args.methods:
method_name, extra_kwargs = parse_method_arg(method_arg) method_name, extra_kwargs = parse_method_arg(method_arg)
if method_name not in available_methods: if method_name not in available_methods:
parser.error(f"Unknown method: {method_name}. " parser.error(f"Unknown method: {method_name}. " f"Available methods: {', '.join(available_methods.keys())}")
f"Available methods: {', '.join(available_methods.keys())}")
module_path, function_name = available_methods[method_name] module_path, function_name = available_methods[method_name]
# Dynamically import the module and get the function. # Dynamically import the module and get the function.
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
function = getattr(module, function_name) function = getattr(module, function_name)
config[method_name] = { config[method_name] = {"method": function, "kwargs": extra_kwargs}
"method": function,
"kwargs": extra_kwargs
}
data_directory = os.path.join(os.path.dirname(__file__), "sample_data") data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
pdf_directory = os.path.join(data_directory, "pdfs") pdf_directory = os.path.join(data_directory, "pdfs")

View File

@ -1,16 +1,14 @@
# type: ignore
import os import os
import tempfile
import base64
import torch
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.data.anchor import get_anchor_text
from olmocr.data.prompts import build_openai_silver_data_prompt
from openai import OpenAI from openai import OpenAI
from olmocr.data.anchor import get_anchor_text
from olmocr.data.prompts import build_openai_silver_data_prompt
from olmocr.data.renderpdf import render_pdf_to_base64png
def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06') -> str:
def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06") -> str:
""" """
Convert page of a PDF file to markdown using GOT-OCR. Convert page of a PDF file to markdown using GOT-OCR.
@ -24,30 +22,28 @@ def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06')
str: The OCR result in markdown format. str: The OCR result in markdown format.
""" """
# Convert the first page of the PDF to a base64-encoded PNG image. # Convert the first page of the PDF to a base64-encoded PNG image.
base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024) # base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport") anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.chat.completions.create( response = client.chat.completions.create(
model=model, model=model,
messages= [ messages=[
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)}, {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}} {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
], ],
} }
], ],
temperature=0.1, temperature=0.1,
max_tokens=3000, max_tokens=3000,
logprobs=True, logprobs=True,
top_logprobs=5, top_logprobs=5,
response_format=openai_response_format_schema() response_format=openai_response_format_schema(),
) )
print(response) print(response)
return result return result

View File

@ -1,16 +1,18 @@
import base64
import os import os
import tempfile import tempfile
import base64
import torch import torch
from transformers import AutoModel, AutoTokenizer
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from transformers import AutoModel, AutoTokenizer
# Global cache for the model and tokenizer. # Global cache for the model and tokenizer.
_device = "cuda" if torch.cuda.is_available() else "cpu" _device = "cuda" if torch.cuda.is_available() else "cpu"
_model = None _model = None
_tokenizer = None _tokenizer = None
def load_model(): def load_model():
""" """
Load the GOT-OCR model and tokenizer if they haven't been loaded already. Load the GOT-OCR model and tokenizer if they haven't been loaded already.
@ -20,21 +22,19 @@ def load_model():
""" """
global _model, _tokenizer global _model, _tokenizer
if _model is None or _tokenizer is None: if _model is None or _tokenizer is None:
_tokenizer = AutoTokenizer.from_pretrained( _tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
'ucaslcl/GOT-OCR2_0', trust_remote_code=True
)
_model = AutoModel.from_pretrained( _model = AutoModel.from_pretrained(
'ucaslcl/GOT-OCR2_0', "ucaslcl/GOT-OCR2_0",
trust_remote_code=True, trust_remote_code=True,
use_safetensors=True, use_safetensors=True,
revision="979938bf89ccdc949c0131ddd3841e24578a4742", revision="979938bf89ccdc949c0131ddd3841e24578a4742",
pad_token_id=_tokenizer.eos_token_id pad_token_id=_tokenizer.eos_token_id,
) )
_model = _model.eval().to(_device) _model = _model.eval().to(_device)
return _model, _tokenizer return _model, _tokenizer
def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str: def run_gotocr(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
""" """
Convert page of a PDF file to markdown using GOT-OCR. Convert page of a PDF file to markdown using GOT-OCR.
@ -65,5 +65,3 @@ def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str:
os.remove(tmp_filename) os.remove(tmp_filename)
return result return result

View File

@ -1,15 +1,11 @@
import os
import time
import argparse
from marker.converters.pdf import PdfConverter from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict from marker.models import create_model_dict
from marker.output import text_from_rendered from marker.output import text_from_rendered
_marker_converter = None _marker_converter = None
def run_marker(pdf_path: str, page_num: int=1) -> str:
def run_marker(pdf_path: str, page_num: int = 1) -> str:
global _marker_converter global _marker_converter
if _marker_converter is None: if _marker_converter is None:
@ -22,4 +18,3 @@ def run_marker(pdf_path: str, page_num: int=1) -> str:
text, _, images = text_from_rendered(rendered) text, _, images = text_from_rendered(rendered)
return text return text

View File

@ -1,15 +1,13 @@
import os import os
import shutil
import argparse
import tempfile import tempfile
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.config.enums import SupportedPdfParseMethod
def run_mineru(pdf_path: str, page_num: int=1) -> str: def run_mineru(pdf_path: str, page_num: int = 1) -> str:
output_folder = tempfile.TemporaryDirectory() output_folder = tempfile.TemporaryDirectory()
image_output_folder = tempfile.TemporaryDirectory() image_output_folder = tempfile.TemporaryDirectory()
@ -34,7 +32,7 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
# Generate markdown content; the image directory is the basename of the images output folder # Generate markdown content; the image directory is the basename of the images output folder
image_dir_basename = os.path.basename(image_output_folder.name) image_dir_basename = os.path.basename(image_output_folder.name)
md_content = pipe_result.get_markdown(image_dir_basename) # md_content = pipe_result.get_markdown(image_dir_basename)
# Dump markdown file # Dump markdown file
with tempfile.NamedTemporaryFile("w+", suffix="md") as tf: with tempfile.NamedTemporaryFile("w+", suffix="md") as tf:
@ -45,4 +43,3 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
md_data = tf.read() md_data = tf.read()
return md_data return md_data

View File

@ -1,9 +1,10 @@
import sys import asyncio
import glob import glob
import json import json
import os import os
import shutil import shutil
import asyncio import sys
import olmocr.pipeline import olmocr.pipeline
# Set sys.argv as if you were running the script from the command line. # Set sys.argv as if you were running the script from the command line.
@ -11,9 +12,10 @@ import olmocr.pipeline
workspace_dir = "olmocr/bench/sample_data/olmocr/workspace" workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
sys.argv = [ sys.argv = [
"pipeline.py", # The script name (can be arbitrary) "pipeline.py", # The script name (can be arbitrary)
"olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace "olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace
"--pdfs", *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths "--pdfs",
*list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
] ]
# Call the async main() function. # Call the async main() function.

View File

@ -100,7 +100,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pd
try: try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile: with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1): for line_number, line in enumerate(infile, 1):
line = line.strip() line = line.strip()
if not line: if not line:

View File

@ -34,7 +34,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
try: try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile: with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1): for line_number, line in enumerate(infile, 1):
line = line.strip() line = line.strip()
if not line: if not line:

View File

@ -168,5 +168,4 @@ class BenchmarkRepeatDetect(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -1 +1,2 @@
class DolmaRefineError(RuntimeError): ... class DolmaRefineError(RuntimeError):
...

View File

@ -5,10 +5,7 @@ import re
from typing import Optional from typing import Optional
import boto3 import boto3
from datasets import ( from datasets import Dataset, load_dataset
Dataset,
load_dataset,
)
from filelock import FileLock from filelock import FileLock
from olmocr.data.renderpdf import get_pdf_media_box_width_height from olmocr.data.renderpdf import get_pdf_media_box_width_height

View File

@ -4,17 +4,11 @@ from io import BytesIO
import torch import torch
import torch.distributed import torch.distributed
from PIL import Image from PIL import Image
from transformers import ( from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
AutoConfig,
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
)
from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import ( from olmocr.prompts.prompts import build_openai_silver_data_prompt
build_openai_silver_data_prompt,
)
@torch.no_grad() @torch.no_grad()

View File

@ -9,11 +9,7 @@ import torchvision.transforms
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ( from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
)
from transformers.processing_utils import ImagesKwargs from transformers.processing_utils import ImagesKwargs
from transformers.utils import logging from transformers.utils import logging

View File

@ -740,7 +740,6 @@ class ViTMLP(nn.Module):
class ResidualAttentionBlock(nn.Module): class ResidualAttentionBlock(nn.Module):
def __init__(self, config: FullMolmoConfig): def __init__(self, config: FullMolmoConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@ -772,7 +771,6 @@ class ResidualAttentionBlock(nn.Module):
class BlockCollection(nn.Module): class BlockCollection(nn.Module):
def __init__(self, config: FullMolmoConfig): def __init__(self, config: FullMolmoConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@ -801,7 +799,6 @@ class LayerNormFp32(nn.LayerNorm):
class VisionTransformer(nn.Module): class VisionTransformer(nn.Module):
def __init__(self, config: FullMolmoConfig): def __init__(self, config: FullMolmoConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@ -952,7 +949,6 @@ class MultiHeadDotProductAttention(nn.Module):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
if inputs_kv is not None: if inputs_kv is not None:
inputs_k = inputs_kv inputs_k = inputs_kv
inputs_v = inputs_kv inputs_v = inputs_kv
@ -1099,7 +1095,6 @@ class MultiHeadAttentionPool(nn.Module):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor: def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv) xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
if self.query == "mean": if self.query == "mean":

View File

@ -16,11 +16,7 @@ import numpy as np
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.processing_utils import ( from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
ProcessingKwargs,
ProcessorMixin,
TextKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging from transformers.utils import logging

View File

@ -147,7 +147,6 @@ def run_train(config: TrainConfig):
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
with TemporaryDirectory() as output_dir: with TemporaryDirectory() as output_dir:
training_args = TrainingArguments( training_args = TrainingArguments(
run_name=run_name.run, run_name=run_name.run,
logging_steps=config.hparams.log_every_steps, logging_steps=config.hparams.log_every_steps,

View File

@ -224,7 +224,6 @@ class TruncatingCollator:
@contextmanager @contextmanager
def get_local_dir(output_dir: str): def get_local_dir(output_dir: str):
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
if is_local(output_dir): if is_local(output_dir):
yield output_dir yield output_dir

View File

@ -362,7 +362,6 @@ async def run_vllm_async(
) )
async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm: async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine. # Add the requests to the engine.
prompts: List[str] = [] prompts: List[str] = []
sampling_params: List[SamplingParams] = [] sampling_params: List[SamplingParams] = []

View File

@ -1,34 +1,37 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import boto3
import json import json
import random import random
import re import re
import requests
import time import time
import boto3
import requests
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoTokenizer from transformers import AutoTokenizer
# Allowed characters: alphanumeric, space, and basic punctuation ".,!?()" # Allowed characters: alphanumeric, space, and basic punctuation ".,!?()"
ALLOWED_RE = re.compile(r'^[A-Za-z0-9\.,!?() ]+$') ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
def get_random_line_from_s3(bucket, key): def get_random_line_from_s3(bucket, key):
""" """
Reads an S3 object line-by-line and returns a random line using reservoir sampling. Reads an S3 object line-by-line and returns a random line using reservoir sampling.
""" """
s3 = boto3.client('s3') s3 = boto3.client("s3")
response = s3.get_object(Bucket=bucket, Key=key) response = s3.get_object(Bucket=bucket, Key=key)
random_line = None random_line = None
count = 0 count = 0
for line in response['Body'].iter_lines(): for line in response["Body"].iter_lines():
if not line: if not line:
continue continue
line_str = line.decode('utf-8') line_str = line.decode("utf-8")
count += 1 count += 1
if random.randint(1, count) == 1: if random.randint(1, count) == 1:
random_line = line_str random_line = line_str
return random_line return random_line
def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3): def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
""" """
Sends a count query to the infini-gram API for the given n-gram. Sends a count query to the infini-gram API for the given n-gram.
@ -47,10 +50,11 @@ def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
result = response.json() result = response.json()
if "count" in result: if "count" in result:
return result["count"] return result["count"]
except Exception as e: except Exception as e: # type: ignore
time.sleep(1) time.sleep(1)
return 0 return 0
def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"): def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
""" """
Tokenizes the document using the Llama2 tokenizer and samples random n-grams. Tokenizes the document using the Llama2 tokenizer and samples random n-grams.
@ -67,7 +71,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
# Get tokenized representation with offset mapping to determine word boundaries. # Get tokenized representation with offset mapping to determine word boundaries.
tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True) tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
token_ids = tokenized["input_ids"] token_ids = tokenized["input_ids"]
offsets = tokenized["offset_mapping"] # offsets = tokenized["offset_mapping"]
if len(token_ids) < ngram_size: if len(token_ids) < ngram_size:
return doc_id, 0, 0, [] return doc_id, 0, 0, []
@ -88,7 +92,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
max_attempts = num_samples * 10 # Limit to prevent infinite loops. max_attempts = num_samples * 10 # Limit to prevent infinite loops.
while len(valid_ngram_details) < num_samples and attempts < max_attempts: while len(valid_ngram_details) < num_samples and attempts < max_attempts:
idx = random.choice(valid_positions) idx = random.choice(valid_positions)
ngram_token_ids = token_ids[idx: idx+ngram_size] ngram_token_ids = token_ids[idx : idx + ngram_size]
ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True) ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
# Only accept n-grams that contain only allowed characters. # Only accept n-grams that contain only allowed characters.
if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3: if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
@ -101,10 +105,9 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam
sample_count = len(valid_ngram_details) sample_count = len(valid_ngram_details)
return doc_id, match_count, sample_count, valid_ngram_details return doc_id, match_count, sample_count, valid_ngram_details
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
description="Infini-gram n-gram matching script with Llama2 tokenization."
)
parser.add_argument("N", type=int, help="Number of random .jsonl files to process") parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)") parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)") parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
@ -149,9 +152,7 @@ def main():
except Exception as e: except Exception as e:
print(f"Error parsing JSON in {key}: {e}") print(f"Error parsing JSON in {key}: {e}")
continue continue
doc_id, match_count, sample_count, details = process_document( doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)
doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index
)
# Print per-document n-gram summary # Print per-document n-gram summary
print(f"\nDocument ID: {doc_id}") print(f"\nDocument ID: {doc_id}")
@ -167,5 +168,6 @@ def main():
overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0 overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)") print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -16,9 +16,7 @@ from tqdm import tqdm
from transformers import AutoProcessor from transformers import AutoProcessor
from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import ( from olmocr.train.dataloader import build_finetuning_dataset
build_finetuning_dataset,
)
from olmocr.train.dataprep import ( from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training, batch_prepare_data_for_molmo_training,
build_finetuning_prompt, build_finetuning_prompt,
@ -223,7 +221,6 @@ class TestMolmoDataPrep(unittest.TestCase):
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text, patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png, patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
): ):
# Set return values for the mocked functions # Set return values for the mocked functions
mock_get_anchor_text.return_value = "This is the anchor text." mock_get_anchor_text.return_value = "This is the anchor text."
# Create a red square image and encode it in base64 # Create a red square image and encode it in base64
@ -305,7 +302,6 @@ class TestMolmoDataPrep(unittest.TestCase):
patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text, patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png, patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
): ):
mock_get_anchor_text.return_value = "This is the anchor text." mock_get_anchor_text.return_value = "This is the anchor text."
img = Image.new("RGB", (100, 100), color="red") img = Image.new("RGB", (100, 100), color="red")
buffered = BytesIO() buffered = BytesIO()