diff --git a/olmocr/bench/benchmark.py b/olmocr/bench/benchmark.py index d75ec00..2703299 100644 --- a/olmocr/bench/benchmark.py +++ b/olmocr/bench/benchmark.py @@ -12,14 +12,15 @@ The final score is averaged over the repeated generations. """ import argparse -import os -import json import glob -import sys import itertools +import json +import os +import sys -from rapidfuzz import fuzz from fuzzysearch import find_near_matches +from rapidfuzz import fuzz + def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]): """ @@ -30,7 +31,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]): rules = [] rule_ids = set() - with open(jsonl_path, 'r', encoding='utf-8') as f: + with open(jsonl_path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, start=1): line = line.strip() if not line: @@ -75,6 +76,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]): rules.append(data) return rules + def run_rule(rule, md_file_path: str) -> (bool, str): """ Run the given rule on the content of the provided .md file. @@ -82,7 +84,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str): and 'explanation' is a short message explaining the failure when the rule does not pass. """ try: - with open(md_file_path, 'r', encoding='utf-8') as f: + with open(md_file_path, "r", encoding="utf-8") as f: md_content = f.read() except Exception as e: return (False, f"Error reading {md_file_path}: {e}") @@ -121,15 +123,16 @@ def run_rule(rule, md_file_path: str) -> (bool, str): else: raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.") + def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]): """ For the candidate folder (pipeline tool output), validate that it contains at least one .md file (i.e. repeated generations like _1.md, _2.md, etc.) for every PDF in the pdf folder. Then, run each rule against all corresponding .md files and average the results. - + Returns a tuple: (overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown) - + - overall_score: Average fraction of rules passed (averaged over repeats and rules). - total_rules: Total number of rules evaluated. - candidate_errors: List of candidate errors (e.g. missing files). @@ -148,9 +151,7 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md") md_files = glob.glob(md_pattern) if not md_files: - candidate_errors.append( - f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md)." 
- ) + candidate_errors.append(f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md).") else: pdf_to_md_files[pdf_name] = md_files @@ -195,11 +196,14 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: li overall_score = total_rule_score / len(all_rules) if all_rules else 0.0 return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown) + def main(): parser = argparse.ArgumentParser(description="Run OLMOCR Bench.") - parser.add_argument("--input_folder", - default=os.path.join(os.path.dirname(__file__), "sample_data"), - help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.") + parser.add_argument( + "--input_folder", + default=os.path.join(os.path.dirname(__file__), "sample_data"), + help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.", + ) args = parser.parse_args() input_folder = args.input_folder @@ -268,7 +272,7 @@ def main(): print(f" Average Score: {overall_score * 100:.1f}% over {total_rules} rules.") # Print final summary with breakdown by rule type - print("\n" + "="*50) + print("\n" + "=" * 50) print("Final Summary:") for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary: if candidate_errors: @@ -283,7 +287,8 @@ def main(): else: avg = 0.0 print(f" {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules") - print("="*50) + print("=" * 50) + if __name__ == "__main__": main() diff --git a/olmocr/bench/convert.py b/olmocr/bench/convert.py index 90cb77a..ccf9fbc 100644 --- a/olmocr/bench/convert.py +++ b/olmocr/bench/convert.py @@ -1,9 +1,11 @@ import argparse -import os import glob import importlib +import os + from tqdm import tqdm + def parse_method_arg(method_arg): """ Parse a method configuration string of the form: @@ -29,22 +31,11 @@ def parse_method_arg(method_arg): raise ValueError(f"Extra argument '{extra}' is not in key=value format") return name, kwargs + if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Run PDF conversion using specified OCR methods and extra parameters." - ) - parser.add_argument( - "methods", - nargs="+", - help="Methods to run in the format method[:key=value ...]. " - "Example: gotocr mineru:temperature=2 marker:runs=3" - ) - parser.add_argument( - "--repeats", - type=int, - default=1, - help="Number of times to repeat the conversion for each PDF." - ) + parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.") + parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3") + parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.") args = parser.parse_args() # Mapping of method names to a tuple: (module path, function name) @@ -60,16 +51,12 @@ if __name__ == "__main__": for method_arg in args.methods: method_name, extra_kwargs = parse_method_arg(method_arg) if method_name not in available_methods: - parser.error(f"Unknown method: {method_name}. " - f"Available methods: {', '.join(available_methods.keys())}") + parser.error(f"Unknown method: {method_name}. " f"Available methods: {', '.join(available_methods.keys())}") module_path, function_name = available_methods[method_name] # Dynamically import the module and get the function. 
        module = importlib.import_module(module_path)
        function = getattr(module, function_name)
-        config[method_name] = {
-            "method": function,
-            "kwargs": extra_kwargs
-        }
+        config[method_name] = {"method": function, "kwargs": extra_kwargs}
 
    data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
    pdf_directory = os.path.join(data_directory, "pdfs")
diff --git a/olmocr/bench/miners/mine_big_diff.py b/olmocr/bench/miners/mine_big_diff.py
index a6b9298..d748e3a 100644
--- a/olmocr/bench/miners/mine_big_diff.py
+++ b/olmocr/bench/miners/mine_big_diff.py
@@ -6,4 +6,4 @@
 
 # Then, prompt again to generate the set of absent/present rules, given the diffs presented
 
-# Then, run those rules through a tinyhost verification/edit system to quickly build up a big set
\ No newline at end of file
+# Then, run those rules through a tinyhost verification/edit system to quickly build up a big set
diff --git a/olmocr/bench/runners/run_chatgpt.py b/olmocr/bench/runners/run_chatgpt.py
index da39e72..5f9e662 100644
--- a/olmocr/bench/runners/run_chatgpt.py
+++ b/olmocr/bench/runners/run_chatgpt.py
@@ -1,53 +1,49 @@
+# type: ignore
 import os
-import tempfile
-import base64
-import torch
-
-from olmocr.data.renderpdf import render_pdf_to_base64png
-from olmocr.data.anchor import get_anchor_text
-from olmocr.data.prompts import build_openai_silver_data_prompt
 from openai import OpenAI
-
-def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06') -> str:
+
+from olmocr.data.anchor import get_anchor_text
+from olmocr.data.prompts import build_openai_silver_data_prompt, openai_response_format_schema
+from olmocr.data.renderpdf import render_pdf_to_base64png
+
+
+def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06") -> str:
     """
-    Convert page of a PDF file to markdown using GOT-OCR.
-    
+    Convert a page of a PDF file to markdown using an OpenAI chat model.
+
     This function renders the first page of the PDF to an image, runs OCR
     on that image, and returns the OCR result as a markdown-formatted string.
-    
+
     Args:
         pdf_path (str): The local path to the PDF file.
-    
+
     Returns:
         str: The OCR result in markdown format.
     """
     # Convert the first page of the PDF to a base64-encoded PNG image.
     base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
-    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
+    anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
 
     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
     response = client.chat.completions.create(
         model=model,
-        messages= [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
-                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
-                ],
-            }
-        ],
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
+                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64image}"}},
+                ],
+            }
+        ],
         temperature=0.1,
         max_tokens=3000,
         logprobs=True,
         top_logprobs=5,
-        response_format=openai_response_format_schema()
+        response_format=openai_response_format_schema(),
     )
 
     print(response)
-    
-    return result
-
-
+
+    return response.choices[0].message.content  # return the model's text output
diff --git a/olmocr/bench/runners/run_gotocr.py b/olmocr/bench/runners/run_gotocr.py
index a95aa92..8cae98e 100644
--- a/olmocr/bench/runners/run_gotocr.py
+++ b/olmocr/bench/runners/run_gotocr.py
@@ -1,16 +1,18 @@
+import base64
 import os
 import tempfile
-import base64
+
 import torch
+from transformers import AutoModel, AutoTokenizer
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
-from transformers import AutoModel, AutoTokenizer
 
 # Global cache for the model and tokenizer.
 _device = "cuda" if torch.cuda.is_available() else "cpu"
 _model = None
 _tokenizer = None
 
+
 def load_model():
     """
     Load the GOT-OCR model and tokenizer if they haven't been loaded already.
@@ -20,50 +22,46 @@
     """
     global _model, _tokenizer
     if _model is None or _tokenizer is None:
-        _tokenizer = AutoTokenizer.from_pretrained(
-            'ucaslcl/GOT-OCR2_0', trust_remote_code=True
-        )
+        _tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
         _model = AutoModel.from_pretrained(
-            'ucaslcl/GOT-OCR2_0',
+            "ucaslcl/GOT-OCR2_0",
             trust_remote_code=True,
             use_safetensors=True,
             revision="979938bf89ccdc949c0131ddd3841e24578a4742",
-            pad_token_id=_tokenizer.eos_token_id
+            pad_token_id=_tokenizer.eos_token_id,
         )
         _model = _model.eval().to(_device)
     return _model, _tokenizer
 
 
-def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str:
+def run_gotocr(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
     """
     Convert page of a PDF file to markdown using GOT-OCR.
-    
+
     This function renders the first page of the PDF to an image, runs OCR
     on that image, and returns the OCR result as a markdown-formatted string.
-    
+
     Args:
         pdf_path (str): The local path to the PDF file.
-    
+
     Returns:
         str: The OCR result in markdown format.
     """
     # Ensure the model is loaded (cached across calls)
     model, tokenizer = load_model()
-    
+
     # Convert the first page of the PDF to a base64-encoded PNG image.
     base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
-    
+
     # Write the image to a temporary file.
     with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as tmp:
         tmp.write(base64.b64decode(base64image))
         tmp_filename = tmp.name
-    
+
     # Run GOT-OCR on the saved image.
     result = model.chat(tokenizer, tmp_filename, ocr_type=ocr_type)
-    
+
     # Clean up the temporary file.
    os.remove(tmp_filename)
-    
+
     return result
-
-
diff --git a/olmocr/bench/runners/run_marker.py b/olmocr/bench/runners/run_marker.py
index 68461a5..920963a 100644
--- a/olmocr/bench/runners/run_marker.py
+++ b/olmocr/bench/runners/run_marker.py
@@ -1,15 +1,11 @@
-import os
-import time
-import argparse
-
-
 from marker.converters.pdf import PdfConverter
 from marker.models import create_model_dict
 from marker.output import text_from_rendered
 
 _marker_converter = None
 
-def run_marker(pdf_path: str, page_num: int=1) -> str:
+
+def run_marker(pdf_path: str, page_num: int = 1) -> str:
     global _marker_converter
 
     if _marker_converter is None:
@@ -22,4 +18,3 @@ def run_marker(pdf_path: str, page_num: int=1) -> str:
     text, _, images = text_from_rendered(rendered)
 
     return text
-
diff --git a/olmocr/bench/runners/run_mineru.py b/olmocr/bench/runners/run_mineru.py
index 8ca5447..6dc1d5e 100644
--- a/olmocr/bench/runners/run_mineru.py
+++ b/olmocr/bench/runners/run_mineru.py
@@ -1,15 +1,13 @@
 import os
-import shutil
-import argparse
 import tempfile
 
-from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
 from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.config.enums import SupportedPdfParseMethod
 
-def run_mineru(pdf_path: str, page_num: int=1) -> str:
+
+def run_mineru(pdf_path: str, page_num: int = 1) -> str:
     output_folder = tempfile.TemporaryDirectory()
     image_output_folder = tempfile.TemporaryDirectory()
 
@@ -34,7 +32,7 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
 
     # Generate markdown content; the image directory is the basename of the images output folder
     image_dir_basename = os.path.basename(image_output_folder.name)
     md_content = pipe_result.get_markdown(image_dir_basename)
 
     # Dump markdown file
     with tempfile.NamedTemporaryFile("w+", suffix="md") as tf:
@@ -45,4 +43,3 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
     md_data = tf.read()
 
     return md_data
-
diff --git a/olmocr/bench/runners/run_olmocr.py b/olmocr/bench/runners/run_olmocr.py
index 35b257e..ec39c86 100644
--- a/olmocr/bench/runners/run_olmocr.py
+++ b/olmocr/bench/runners/run_olmocr.py
@@ -1,9 +1,10 @@
-import sys
+import asyncio
 import glob
 import json
 import os
 import shutil
-import asyncio
+import sys
+
 import olmocr.pipeline
 
 # Set sys.argv as if you were running the script from the command line.
@@ -11,9 +12,10 @@ import olmocr.pipeline
 workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
 
 sys.argv = [
-    "pipeline.py",  # The script name (can be arbitrary)
-    "olmocr/bench/sample_data/olmocr/workspace",  # Positional argument: workspace
-    "--pdfs", *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")),  # PDF paths
+    "pipeline.py",  # The script name (can be arbitrary)
+    "olmocr/bench/sample_data/olmocr/workspace",  # Positional argument: workspace
+    "--pdfs",
+    *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")),  # PDF paths
 ]
 
 # Call the async main() function.
@@ -30,4 +32,4 @@ for jsonl_path in glob.glob(workspace_dir + "/results/*.jsonl"): with open(f"olmocr/bench/sample_data/olmocr/{name.replace('.pdf', '.md')}", "w") as out_f: out_f.write(data["text"]) -shutil.rmtree(workspace_dir) \ No newline at end of file +shutil.rmtree(workspace_dir) diff --git a/olmocr/data/convertsilver_birr.py b/olmocr/data/convertsilver_birr.py index b6760b7..b3ee178 100644 --- a/olmocr/data/convertsilver_birr.py +++ b/olmocr/data/convertsilver_birr.py @@ -100,7 +100,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pd try: with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile: - for line_number, line in enumerate(infile, 1): line = line.strip() if not line: diff --git a/olmocr/data/convertsilver_openai.py b/olmocr/data/convertsilver_openai.py index ab89d64..055fec4 100644 --- a/olmocr/data/convertsilver_openai.py +++ b/olmocr/data/convertsilver_openai.py @@ -34,7 +34,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool): try: with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile: - for line_number, line in enumerate(infile, 1): line = line.strip() if not line: diff --git a/olmocr/repeatdetect.py b/olmocr/repeatdetect.py index fb42d27..76166e4 100644 --- a/olmocr/repeatdetect.py +++ b/olmocr/repeatdetect.py @@ -168,5 +168,4 @@ class BenchmarkRepeatDetect(unittest.TestCase): if __name__ == "__main__": - unittest.main() diff --git a/olmocr/train/core/errors.py b/olmocr/train/core/errors.py index a24dbe0..afe3e4c 100644 --- a/olmocr/train/core/errors.py +++ b/olmocr/train/core/errors.py @@ -1 +1,2 @@ -class DolmaRefineError(RuntimeError): ... +class DolmaRefineError(RuntimeError): + ... 
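Side note on the convert.py hunks earlier in this patch: parse_method_arg, whose body is only partly visible in the diff, turns a positional argument of the form method[:key=value ...] into a method name plus a kwargs dict, raising ValueError for extras that are not key=value. A minimal sketch of that contract, assuming values are left as strings (the real helper may coerce numeric types):

    def parse_method_arg_sketch(method_arg: str):
        # Split "mineru:temperature=2" into ("mineru", {"temperature": "2"}).
        name, *extras = method_arg.split(":")
        kwargs = {}
        for extra in extras:
            if "=" not in extra:
                raise ValueError(f"Extra argument '{extra}' is not in key=value format")
            key, value = extra.split("=", 1)
            kwargs[key] = value  # assumption: no type coercion here
        return name, kwargs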
diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 7304399..320aff2 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -5,10 +5,7 @@ import re from typing import Optional import boto3 -from datasets import ( - Dataset, - load_dataset, -) +from datasets import Dataset, load_dataset from filelock import FileLock from olmocr.data.renderpdf import get_pdf_media_box_width_height diff --git a/olmocr/train/inference.py b/olmocr/train/inference.py index e03ec7c..d28bffa 100644 --- a/olmocr/train/inference.py +++ b/olmocr/train/inference.py @@ -4,17 +4,11 @@ from io import BytesIO import torch import torch.distributed from PIL import Image -from transformers import ( - AutoConfig, - AutoProcessor, - Qwen2_5_VLForConditionalGeneration, -) +from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.prompts.anchor import get_anchor_text -from olmocr.prompts.prompts import ( - build_openai_silver_data_prompt, -) +from olmocr.prompts.prompts import build_openai_silver_data_prompt @torch.no_grad() diff --git a/olmocr/train/molmo/image_processing_molmo.py b/olmocr/train/molmo/image_processing_molmo.py index f3c20f1..ba68435 100644 --- a/olmocr/train/molmo/image_processing_molmo.py +++ b/olmocr/train/molmo/image_processing_molmo.py @@ -9,11 +9,7 @@ import torchvision.transforms from torchvision.transforms import InterpolationMode from torchvision.transforms.functional import convert_image_dtype from transformers.image_processing_utils import BaseImageProcessor -from transformers.image_utils import ( - OPENAI_CLIP_MEAN, - OPENAI_CLIP_STD, - ImageInput, -) +from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput from transformers.processing_utils import ImagesKwargs from transformers.utils import logging diff --git a/olmocr/train/molmo/modeling_molmo.py b/olmocr/train/molmo/modeling_molmo.py index e4d8460..da274e4 100644 --- a/olmocr/train/molmo/modeling_molmo.py +++ b/olmocr/train/molmo/modeling_molmo.py @@ -740,7 +740,6 @@ class ViTMLP(nn.Module): class ResidualAttentionBlock(nn.Module): - def __init__(self, config: FullMolmoConfig): super().__init__() self.config = config @@ -772,7 +771,6 @@ class ResidualAttentionBlock(nn.Module): class BlockCollection(nn.Module): - def __init__(self, config: FullMolmoConfig): super().__init__() self.config = config @@ -801,7 +799,6 @@ class LayerNormFp32(nn.LayerNorm): class VisionTransformer(nn.Module): - def __init__(self, config: FullMolmoConfig): super().__init__() self.config = config @@ -952,7 +949,6 @@ class MultiHeadDotProductAttention(nn.Module): return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: - if inputs_kv is not None: inputs_k = inputs_kv inputs_v = inputs_kv @@ -1099,7 +1095,6 @@ class MultiHeadAttentionPool(nn.Module): return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor: - xk, xv = self.wk(inputs_kv), self.wv(inputs_kv) if self.query == "mean": diff --git a/olmocr/train/molmo/preprocessing_molmo.py b/olmocr/train/molmo/preprocessing_molmo.py index 3c9dabb..a7c63cc 100644 --- a/olmocr/train/molmo/preprocessing_molmo.py +++ b/olmocr/train/molmo/preprocessing_molmo.py @@ -16,11 +16,7 @@ import numpy as np import torch from transformers import AutoTokenizer from 
transformers.image_utils import ImageInput -from transformers.processing_utils import ( - ProcessingKwargs, - ProcessorMixin, - TextKwargs, -) +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from transformers.utils import logging diff --git a/olmocr/train/train.py b/olmocr/train/train.py index 19f96af..eb40449 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -147,7 +147,6 @@ def run_train(config: TrainConfig): save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore with TemporaryDirectory() as output_dir: - training_args = TrainingArguments( run_name=run_name.run, logging_steps=config.hparams.log_every_steps, diff --git a/olmocr/train/utils.py b/olmocr/train/utils.py index f5867d4..0ab59da 100644 --- a/olmocr/train/utils.py +++ b/olmocr/train/utils.py @@ -224,7 +224,6 @@ class TruncatingCollator: @contextmanager def get_local_dir(output_dir: str): - with TemporaryDirectory() as tmp_dir: if is_local(output_dir): yield output_dir diff --git a/scripts/benchmark_throughput.py b/scripts/benchmark_throughput.py index 8c11ecf..ad1adee 100644 --- a/scripts/benchmark_throughput.py +++ b/scripts/benchmark_throughput.py @@ -362,7 +362,6 @@ async def run_vllm_async( ) async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm: - # Add the requests to the engine. prompts: List[str] = [] sampling_params: List[SamplingParams] = [] diff --git a/scripts/infinigram_count.py b/scripts/infinigram_count.py index 7225220..10efe07 100644 --- a/scripts/infinigram_count.py +++ b/scripts/infinigram_count.py @@ -1,34 +1,37 @@ #!/usr/bin/env python3 import argparse -import boto3 import json import random import re -import requests import time + +import boto3 +import requests from tqdm import tqdm from transformers import AutoTokenizer # Allowed characters: alphanumeric, space, and basic punctuation ".,!?()" -ALLOWED_RE = re.compile(r'^[A-Za-z0-9\.,!?() ]+$') +ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$") + def get_random_line_from_s3(bucket, key): """ Reads an S3 object line-by-line and returns a random line using reservoir sampling. """ - s3 = boto3.client('s3') + s3 = boto3.client("s3") response = s3.get_object(Bucket=bucket, Key=key) random_line = None count = 0 - for line in response['Body'].iter_lines(): + for line in response["Body"].iter_lines(): if not line: continue - line_str = line.decode('utf-8') + line_str = line.decode("utf-8") count += 1 if random.randint(1, count) == 1: random_line = line_str return random_line + def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3): """ Sends a count query to the infini-gram API for the given n-gram. @@ -47,17 +50,18 @@ def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3): result = response.json() if "count" in result: return result["count"] - except Exception as e: + except Exception as e: # type: ignore time.sleep(1) return 0 + def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"): """ Tokenizes the document using the Llama2 tokenizer and samples random n-grams. Each n-gram is chosen such that: 1. It starts on a word-split boundary (using the offset mapping and a check on the preceding character). 2. Its decoded string contains only alphanumeric characters, spaces, and the punctuation marks ".,!?()". - + Each valid n-gram is then queried using the infini-gram API. 
The function returns the document id, the number of matching n-grams (i.e. API count > 0), the total number of valid n-grams sampled, and a list of tuples (flag, ngram_string). @@ -67,7 +71,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam # Get tokenized representation with offset mapping to determine word boundaries. tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True) token_ids = tokenized["input_ids"] - offsets = tokenized["offset_mapping"] + # offsets = tokenized["offset_mapping"] if len(token_ids) < ngram_size: return doc_id, 0, 0, [] @@ -78,17 +82,17 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam # start_offset = offsets[i][0] # if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "): # valid_positions.append(i) - + if not valid_positions: # Fallback: if no valid positions are found, use all possible positions. valid_positions = list(range(len(token_ids) - ngram_size + 1)) - + valid_ngram_details = [] attempts = 0 max_attempts = num_samples * 10 # Limit to prevent infinite loops. while len(valid_ngram_details) < num_samples and attempts < max_attempts: idx = random.choice(valid_positions) - ngram_token_ids = token_ids[idx: idx+ngram_size] + ngram_token_ids = token_ids[idx : idx + ngram_size] ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True) # Only accept n-grams that contain only allowed characters. if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3: @@ -101,10 +105,9 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llam sample_count = len(valid_ngram_details) return doc_id, match_count, sample_count, valid_ngram_details + def main(): - parser = argparse.ArgumentParser( - description="Infini-gram n-gram matching script with Llama2 tokenization." 
- ) + parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.") parser.add_argument("N", type=int, help="Number of random .jsonl files to process") parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)") parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)") @@ -149,10 +152,8 @@ def main(): except Exception as e: print(f"Error parsing JSON in {key}: {e}") continue - doc_id, match_count, sample_count, details = process_document( - doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index - ) - + doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index) + # Print per-document n-gram summary print(f"\nDocument ID: {doc_id}") for flag, ngram in details: @@ -160,12 +161,13 @@ def main(): print(f"{flag:4} {repr(ngram)}") percentage = (match_count / sample_count * 100) if sample_count else 0 print(f"Matched n-grams: {match_count}/{sample_count} ({percentage:.2f}%)") - + total_matches += match_count total_ngrams_sampled += sample_count overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0 print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)") + if __name__ == "__main__": main() diff --git a/tests/test_dataprep.py b/tests/test_dataprep.py index c13e30a..0aeba62 100644 --- a/tests/test_dataprep.py +++ b/tests/test_dataprep.py @@ -16,9 +16,7 @@ from tqdm import tqdm from transformers import AutoProcessor from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig -from olmocr.train.dataloader import ( - build_finetuning_dataset, -) +from olmocr.train.dataloader import build_finetuning_dataset from olmocr.train.dataprep import ( batch_prepare_data_for_molmo_training, build_finetuning_prompt, @@ -223,7 +221,6 @@ class TestMolmoDataPrep(unittest.TestCase): patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text, patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png, ): - # Set return values for the mocked functions mock_get_anchor_text.return_value = "This is the anchor text." # Create a red square image and encode it in base64 @@ -305,7 +302,6 @@ class TestMolmoDataPrep(unittest.TestCase): patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text, patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png, ): - mock_get_anchor_text.return_value = "This is the anchor text." img = Image.new("RGB", (100, 100), color="red") buffered = BytesIO()
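Side note on get_random_line_from_s3 in the infinigram_count.py hunks above: the random.randint(1, count) == 1 test is single-element reservoir sampling. Keeping the i-th line with probability 1/i makes every line of the stream equally likely to be returned after one pass, without buffering the whole object in memory. A self-contained sketch of the same idea over any iterable (illustrative, not part of the patch):

    import random

    def reservoir_choice(lines):
        # Replace the held line with probability 1/i at the i-th line;
        # after n lines, each line has survived with probability 1/n.
        chosen, count = None, 0
        for line in lines:
            count += 1
            if random.randint(1, count) == 1:
                chosen = line
        return chosen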