Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-02 20:09:08 +00:00

Commit 0130a970c2 ("fixed style")
Parent: c2b54d8525

@@ -12,14 +12,15 @@ The final score is averaged over the repeated generations.
 """
 
 import argparse
-import os
-import json
 import glob
-import sys
 import itertools
+import json
+import os
+import sys
 
-from rapidfuzz import fuzz
 from fuzzysearch import find_near_matches
+from rapidfuzz import fuzz
 
 
 def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
     """
@@ -30,7 +31,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
 
     rules = []
     rule_ids = set()
-    with open(jsonl_path, 'r', encoding='utf-8') as f:
+    with open(jsonl_path, "r", encoding="utf-8") as f:
         for line_num, line in enumerate(f, start=1):
             line = line.strip()
             if not line:
@@ -75,6 +76,7 @@ def validate_jsonl_file(jsonl_path: str, all_pdf_files: list[str]):
         rules.append(data)
     return rules
 
 
 def run_rule(rule, md_file_path: str) -> (bool, str):
     """
     Run the given rule on the content of the provided .md file.
@@ -82,7 +84,7 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
     and 'explanation' is a short message explaining the failure when the rule does not pass.
     """
     try:
-        with open(md_file_path, 'r', encoding='utf-8') as f:
+        with open(md_file_path, "r", encoding="utf-8") as f:
             md_content = f.read()
     except Exception as e:
         return (False, f"Error reading {md_file_path}: {e}")
@@ -121,15 +123,16 @@ def run_rule(rule, md_file_path: str) -> (bool, str):
     else:
         raise NotImplementedError(f"Rule type '{rule_type}' is not implemented.")
 
 
 def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
     """
     For the candidate folder (pipeline tool output), validate that it contains at least one .md file
     (i.e. repeated generations like _1.md, _2.md, etc.) for every PDF in the pdf folder.
     Then, run each rule against all corresponding .md files and average the results.
 
     Returns a tuple:
         (overall_score, total_rules, candidate_errors, rule_failures, rule_type_breakdown)
 
     - overall_score: Average fraction of rules passed (averaged over repeats and rules).
     - total_rules: Total number of rules evaluated.
     - candidate_errors: List of candidate errors (e.g. missing files).
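
Note: the hunk above only reflows this docstring, but the scoring it describes ("averaged over repeats and rules") is worth pinning down. A minimal sketch with hypothetical names, illustrative rather than the repository's actual implementation:

    def overall_score(rule_results: list[list[bool]]) -> float:
        # rule_results[i] holds pass/fail for rule i across the repeats (_1.md, _2.md, ...).
        if not rule_results:
            return 0.0
        # Average each rule over its repeats; a rule with no results contributes 0.
        per_rule = [sum(repeats) / len(repeats) for repeats in rule_results if repeats]
        # Then average the per-rule fractions over all rules.
        return sum(per_rule) / len(rule_results)
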
@@ -148,9 +151,7 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
         md_pattern = os.path.join(candidate_folder, f"{md_base}_*.md")
         md_files = glob.glob(md_pattern)
         if not md_files:
-            candidate_errors.append(
-                f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md)."
-            )
+            candidate_errors.append(f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} (expected files matching {md_base}_*.md).")
         else:
             pdf_to_md_files[pdf_name] = md_files
 
@@ -195,11 +196,14 @@ def evaluate_candidate(candidate_folder: str, all_rules: list, pdf_basenames: list[str]):
     overall_score = total_rule_score / len(all_rules) if all_rules else 0.0
     return (overall_score, len(all_rules), candidate_errors, rule_failures, rule_type_breakdown)
 
 
 def main():
     parser = argparse.ArgumentParser(description="Run OLMOCR Bench.")
-    parser.add_argument("--input_folder",
-                        default=os.path.join(os.path.dirname(__file__), "sample_data"),
-                        help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.")
+    parser.add_argument(
+        "--input_folder",
+        default=os.path.join(os.path.dirname(__file__), "sample_data"),
+        help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.",
+    )
     args = parser.parse_args()
 
     input_folder = args.input_folder
@@ -268,7 +272,7 @@ def main():
             print(f"  Average Score: {overall_score * 100:.1f}% over {total_rules} rules.")
 
     # Print final summary with breakdown by rule type
-    print("\n" + "="*50)
+    print("\n" + "=" * 50)
     print("Final Summary:")
     for candidate_name, overall_score, total_rules, candidate_errors, _, rule_type_breakdown in summary:
         if candidate_errors:
@@ -283,7 +287,8 @@ def main():
         else:
             avg = 0.0
         print(f"  {rtype:8s}: {avg:0.1f}% average pass rate over {len(scores)} rules")
-    print("="*50)
+    print("=" * 50)
 
 
 if __name__ == "__main__":
     main()
@@ -1,9 +1,11 @@
 import argparse
-import os
 import glob
 import importlib
+import os
 
 from tqdm import tqdm
 
 
 def parse_method_arg(method_arg):
     """
     Parse a method configuration string of the form:
@@ -29,22 +31,11 @@ def parse_method_arg(method_arg):
             raise ValueError(f"Extra argument '{extra}' is not in key=value format")
     return name, kwargs
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Run PDF conversion using specified OCR methods and extra parameters."
-    )
-    parser.add_argument(
-        "methods",
-        nargs="+",
-        help="Methods to run in the format method[:key=value ...]. "
-             "Example: gotocr mineru:temperature=2 marker:runs=3"
-    )
-    parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="Number of times to repeat the conversion for each PDF."
-    )
+    parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
+    parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
+    parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
     args = parser.parse_args()
 
     # Mapping of method names to a tuple: (module path, function name)
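
Note: the method[:key=value ...] grammar parsed above can be pinned down with a small stand-in. This sketch assumes extra arguments are colon-separated, as in the example mineru:temperature=2; the function name is illustrative, not the repo's API:

    def parse_method_arg_sketch(method_arg: str):
        # "mineru:temperature=2" -> ("mineru", {"temperature": "2"})
        name, _, rest = method_arg.partition(":")
        kwargs = {}
        for extra in rest.split(":") if rest else []:
            if "=" not in extra:
                raise ValueError(f"Extra argument '{extra}' is not in key=value format")
            key, value = extra.split("=", 1)
            kwargs[key] = value
        return name, kwargs
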
@@ -60,16 +51,12 @@ if __name__ == "__main__":
     for method_arg in args.methods:
         method_name, extra_kwargs = parse_method_arg(method_arg)
         if method_name not in available_methods:
-            parser.error(f"Unknown method: {method_name}. "
-                         f"Available methods: {', '.join(available_methods.keys())}")
+            parser.error(f"Unknown method: {method_name}. " f"Available methods: {', '.join(available_methods.keys())}")
         module_path, function_name = available_methods[method_name]
         # Dynamically import the module and get the function.
         module = importlib.import_module(module_path)
         function = getattr(module, function_name)
-        config[method_name] = {
-            "method": function,
-            "kwargs": extra_kwargs
-        }
+        config[method_name] = {"method": function, "kwargs": extra_kwargs}
 
     data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
     pdf_directory = os.path.join(data_directory, "pdfs")
@@ -6,4 +6,4 @@
 
 # Then, prompt again to generate the set of absent/present rules, given the diffs presented
 
-# Then, run those rules through a tinyhost verification/edit system to quickly build up a big set
+# Then, run those rules through a tinyhost verification/edit system to quickly build up a big set
@@ -1,53 +1,49 @@
 # type: ignore
 import os
 import tempfile
 import base64
 import torch
 
-from olmocr.data.renderpdf import render_pdf_to_base64png
-from olmocr.data.anchor import get_anchor_text
-from olmocr.data.prompts import build_openai_silver_data_prompt
-
 from openai import OpenAI
 
+from olmocr.data.anchor import get_anchor_text
+from olmocr.data.prompts import build_openai_silver_data_prompt
+from olmocr.data.renderpdf import render_pdf_to_base64png
 
-def run_chatgpt(pdf_path: str, page_num: int=1, model: str='gpt-4o-2024-08-06') -> str:
+
+def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06") -> str:
     """
     Convert page of a PDF file to markdown using GOT-OCR.
 
     This function renders the first page of the PDF to an image, runs OCR on that image,
     and returns the OCR result as a markdown-formatted string.
 
     Args:
         pdf_path (str): The local path to the PDF file.
 
     Returns:
         str: The OCR result in markdown format.
     """
     # Convert the first page of the PDF to a base64-encoded PNG image.
-    base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
+    # base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
     anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
 
     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
     response = client.chat.completions.create(
         model=model,
-        messages= [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
-                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
-                ],
-            }
-        ],
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
+                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
+                ],
+            }
+        ],
         temperature=0.1,
         max_tokens=3000,
         logprobs=True,
         top_logprobs=5,
-        response_format=openai_response_format_schema()
+        response_format=openai_response_format_schema(),
     )
     print(response)
 
     return result
@@ -1,16 +1,18 @@
+import base64
 import os
 import tempfile
-import base64
 
 import torch
+from transformers import AutoModel, AutoTokenizer
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
-from transformers import AutoModel, AutoTokenizer
 
 # Global cache for the model and tokenizer.
 _device = "cuda" if torch.cuda.is_available() else "cpu"
 _model = None
 _tokenizer = None
 
 
 def load_model():
     """
     Load the GOT-OCR model and tokenizer if they haven't been loaded already.
@@ -20,50 +22,46 @@ def load_model():
     """
     global _model, _tokenizer
     if _model is None or _tokenizer is None:
-        _tokenizer = AutoTokenizer.from_pretrained(
-            'ucaslcl/GOT-OCR2_0', trust_remote_code=True
-        )
+        _tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
         _model = AutoModel.from_pretrained(
-            'ucaslcl/GOT-OCR2_0',
+            "ucaslcl/GOT-OCR2_0",
             trust_remote_code=True,
             use_safetensors=True,
             revision="979938bf89ccdc949c0131ddd3841e24578a4742",
-            pad_token_id=_tokenizer.eos_token_id
+            pad_token_id=_tokenizer.eos_token_id,
         )
         _model = _model.eval().to(_device)
     return _model, _tokenizer
 
 
-def run_gotocr(pdf_path: str, page_num: int=1, ocr_type: str='ocr') -> str:
+def run_gotocr(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
     """
     Convert page of a PDF file to markdown using GOT-OCR.
 
     This function renders the first page of the PDF to an image, runs OCR on that image,
     and returns the OCR result as a markdown-formatted string.
 
     Args:
         pdf_path (str): The local path to the PDF file.
 
     Returns:
         str: The OCR result in markdown format.
     """
     # Ensure the model is loaded (cached across calls)
     model, tokenizer = load_model()
 
     # Convert the first page of the PDF to a base64-encoded PNG image.
     base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
 
     # Write the image to a temporary file.
     with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as tmp:
         tmp.write(base64.b64decode(base64image))
         tmp_filename = tmp.name
 
     # Run GOT-OCR on the saved image.
     result = model.chat(tokenizer, tmp_filename, ocr_type=ocr_type)
 
     # Clean up the temporary file.
     os.remove(tmp_filename)
 
     return result
@@ -1,15 +1,11 @@
 import os
-import time
-import argparse
-
 
 from marker.converters.pdf import PdfConverter
 from marker.models import create_model_dict
 from marker.output import text_from_rendered
 
 _marker_converter = None
 
-def run_marker(pdf_path: str, page_num: int=1) -> str:
+
+def run_marker(pdf_path: str, page_num: int = 1) -> str:
     global _marker_converter
 
     if _marker_converter is None:
@@ -22,4 +18,3 @@ def run_marker(pdf_path: str, page_num: int = 1) -> str:
     text, _, images = text_from_rendered(rendered)
 
     return text
-
@@ -1,15 +1,13 @@
 import os
 import shutil
 import argparse
 import tempfile
 
-from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
 from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.config.enums import SupportedPdfParseMethod
 
 
-def run_mineru(pdf_path: str, page_num: int=1) -> str:
+def run_mineru(pdf_path: str, page_num: int = 1) -> str:
     output_folder = tempfile.TemporaryDirectory()
     image_output_folder = tempfile.TemporaryDirectory()
 
@@ -34,7 +32,7 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
 
     # Generate markdown content; the image directory is the basename of the images output folder
     image_dir_basename = os.path.basename(image_output_folder.name)
-    md_content = pipe_result.get_markdown(image_dir_basename)
+    # md_content = pipe_result.get_markdown(image_dir_basename)
 
     # Dump markdown file
     with tempfile.NamedTemporaryFile("w+", suffix="md") as tf:
@@ -45,4 +43,3 @@ def run_mineru(pdf_path: str, page_num: int=1) -> str:
         md_data = tf.read()
 
     return md_data
-
@@ -1,9 +1,10 @@
-import sys
+import asyncio
 import glob
 import json
 import os
 import shutil
-import asyncio
+import sys
 
 import olmocr.pipeline
 
 # Set sys.argv as if you were running the script from the command line.
@@ -11,9 +12,10 @@ import olmocr.pipeline
 workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
 
 sys.argv = [
-    "pipeline.py", # The script name (can be arbitrary)
-    "olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace
-    "--pdfs", *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
+    "pipeline.py",  # The script name (can be arbitrary)
+    "olmocr/bench/sample_data/olmocr/workspace",  # Positional argument: workspace
+    "--pdfs",
+    *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")),  # PDF paths
 ]
 
 # Call the async main() function.
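
Note: the pattern in this hunk (faking sys.argv, then driving the CLI's async entry point in-process) looks roughly like the sketch below. That the entry point is olmocr.pipeline.main() is an assumption based on the "# Call the async main() function." comment in the file:

    import asyncio
    import sys

    import olmocr.pipeline

    # Fake the command line the pipeline script expects to see.
    sys.argv = ["pipeline.py", "workspace_dir", "--pdfs", "doc.pdf"]
    # Assumed async entry point; see the comment in the hunk above.
    asyncio.run(olmocr.pipeline.main())
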
@@ -30,4 +32,4 @@ for jsonl_path in glob.glob(workspace_dir + "/results/*.jsonl"):
     with open(f"olmocr/bench/sample_data/olmocr/{name.replace('.pdf', '.md')}", "w") as out_f:
         out_f.write(data["text"])
 
-shutil.rmtree(workspace_dir)
+shutil.rmtree(workspace_dir)
@@ -100,7 +100,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pd
 
     try:
         with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
-
             for line_number, line in enumerate(infile, 1):
                 line = line.strip()
                 if not line:
@@ -34,7 +34,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
 
     try:
         with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
-
             for line_number, line in enumerate(infile, 1):
                 line = line.strip()
                 if not line:
@@ -168,5 +168,4 @@ class BenchmarkRepeatDetect(unittest.TestCase):
 
 
 if __name__ == "__main__":
-
     unittest.main()
@@ -1 +1,2 @@
-class DolmaRefineError(RuntimeError): ...
+class DolmaRefineError(RuntimeError):
+    ...
@@ -5,10 +5,7 @@ import re
 from typing import Optional
 
 import boto3
-from datasets import (
-    Dataset,
-    load_dataset,
-)
+from datasets import Dataset, load_dataset
 from filelock import FileLock
 
 from olmocr.data.renderpdf import get_pdf_media_box_width_height
@@ -4,17 +4,11 @@ from io import BytesIO
 import torch
 import torch.distributed
 from PIL import Image
-from transformers import (
-    AutoConfig,
-    AutoProcessor,
-    Qwen2_5_VLForConditionalGeneration,
-)
+from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.anchor import get_anchor_text
-from olmocr.prompts.prompts import (
-    build_openai_silver_data_prompt,
-)
+from olmocr.prompts.prompts import build_openai_silver_data_prompt
 
 
 @torch.no_grad()
@@ -9,11 +9,7 @@ import torchvision.transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import convert_image_dtype
 from transformers.image_processing_utils import BaseImageProcessor
-from transformers.image_utils import (
-    OPENAI_CLIP_MEAN,
-    OPENAI_CLIP_STD,
-    ImageInput,
-)
+from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
 from transformers.processing_utils import ImagesKwargs
 from transformers.utils import logging
@@ -740,7 +740,6 @@ class ViTMLP(nn.Module):
 
 
 class ResidualAttentionBlock(nn.Module):
-
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
         self.config = config
@@ -772,7 +771,6 @@ class ResidualAttentionBlock(nn.Module):
 
 
 class BlockCollection(nn.Module):
-
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
         self.config = config
@@ -801,7 +799,6 @@ class LayerNormFp32(nn.LayerNorm):
 
 
 class VisionTransformer(nn.Module):
-
     def __init__(self, config: FullMolmoConfig):
         super().__init__()
         self.config = config
@@ -952,7 +949,6 @@ class MultiHeadDotProductAttention(nn.Module):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
 
     def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
-
         if inputs_kv is not None:
             inputs_k = inputs_kv
             inputs_v = inputs_kv
@@ -1099,7 +1095,6 @@ class MultiHeadAttentionPool(nn.Module):
         return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
 
     def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
-
         xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
 
         if self.query == "mean":
@@ -16,11 +16,7 @@ import numpy as np
 import torch
 from transformers import AutoTokenizer
 from transformers.image_utils import ImageInput
-from transformers.processing_utils import (
-    ProcessingKwargs,
-    ProcessorMixin,
-    TextKwargs,
-)
+from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 from transformers.utils import logging
@@ -147,7 +147,6 @@ def run_train(config: TrainConfig):
     save_config(config, join_path("", save_path, "config.yaml"))  # pyright: ignore
 
     with TemporaryDirectory() as output_dir:
-
         training_args = TrainingArguments(
             run_name=run_name.run,
             logging_steps=config.hparams.log_every_steps,
@@ -224,7 +224,6 @@ class TruncatingCollator:
 
 @contextmanager
 def get_local_dir(output_dir: str):
-
     with TemporaryDirectory() as tmp_dir:
         if is_local(output_dir):
             yield output_dir
@@ -362,7 +362,6 @@ async def run_vllm_async(
     )
 
     async with build_async_engine_client_from_engine_args(engine_args, disable_frontend_multiprocessing) as llm:
-
         # Add the requests to the engine.
         prompts: List[str] = []
         sampling_params: List[SamplingParams] = []
@@ -1,34 +1,37 @@
 #!/usr/bin/env python3
 import argparse
-import boto3
 import json
 import random
 import re
-import requests
 import time
 
+import boto3
+import requests
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
 # Allowed characters: alphanumeric, space, and basic punctuation ".,!?()"
-ALLOWED_RE = re.compile(r'^[A-Za-z0-9\.,!?() ]+$')
+ALLOWED_RE = re.compile(r"^[A-Za-z0-9\.,!?() ]+$")
 
 
 def get_random_line_from_s3(bucket, key):
     """
     Reads an S3 object line-by-line and returns a random line using reservoir sampling.
     """
-    s3 = boto3.client('s3')
+    s3 = boto3.client("s3")
     response = s3.get_object(Bucket=bucket, Key=key)
     random_line = None
     count = 0
-    for line in response['Body'].iter_lines():
+    for line in response["Body"].iter_lines():
         if not line:
             continue
-        line_str = line.decode('utf-8')
+        line_str = line.decode("utf-8")
         count += 1
         if random.randint(1, count) == 1:
            random_line = line_str
     return random_line
 
 
 def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
     """
     Sends a count query to the infini-gram API for the given n-gram.
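
Note: the get_random_line_from_s3 changes above are quote-style only; the reservoir-sampling trick it relies on (keep the k-th item with probability 1/k, which makes every item equally likely after a single pass) can be checked in isolation with a sketch like this, with an illustrative function name:

    import random

    def reservoir_pick(items):
        # Single-pass uniform choice over a stream of unknown length:
        # the k-th item replaces the current pick with probability 1/k.
        pick, count = None, 0
        for item in items:
            count += 1
            if random.randint(1, count) == 1:
                pick = item
        return pick
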
@@ -47,17 +50,18 @@ def query_infinigram(ngram, index="v4_rpj_llama_s4", retries=3):
             result = response.json()
             if "count" in result:
                 return result["count"]
-        except Exception as e:
+        except Exception as e:  # type: ignore
             time.sleep(1)
     return 0
 
 
 def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
     """
     Tokenizes the document using the Llama2 tokenizer and samples random n-grams.
     Each n-gram is chosen such that:
       1. It starts on a word-split boundary (using the offset mapping and a check on the preceding character).
       2. Its decoded string contains only alphanumeric characters, spaces, and the punctuation marks ".,!?()".
 
     Each valid n-gram is then queried using the infini-gram API.
     The function returns the document id, the number of matching n-grams (i.e. API count > 0),
     the total number of valid n-grams sampled, and a list of tuples (flag, ngram_string).
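
Note: condition 1 in this docstring (the hunks below show the original check commented out) amounts to the following sketch. The offsets come from the tokenizer's return_offsets_mapping output; the function name is illustrative:

    def word_boundary_positions(text: str, offsets, ngram_size: int):
        # A start position is valid when its token begins at index 0
        # or immediately after a space in the original text.
        valid = []
        for i in range(len(offsets) - ngram_size + 1):
            start = offsets[i][0]
            if start == 0 or text[start - 1] == " ":
                valid.append(i)
        return valid
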
@@ -67,7 +71,7 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
     # Get tokenized representation with offset mapping to determine word boundaries.
     tokenized = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
     token_ids = tokenized["input_ids"]
-    offsets = tokenized["offset_mapping"]
+    # offsets = tokenized["offset_mapping"]
 
     if len(token_ids) < ngram_size:
         return doc_id, 0, 0, []
@@ -78,17 +82,17 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
     # start_offset = offsets[i][0]
     # if start_offset == 0 or (start_offset > 0 and text[start_offset - 1] == " "):
     # valid_positions.append(i)
 
     if not valid_positions:
         # Fallback: if no valid positions are found, use all possible positions.
         valid_positions = list(range(len(token_ids) - ngram_size + 1))
 
     valid_ngram_details = []
     attempts = 0
     max_attempts = num_samples * 10  # Limit to prevent infinite loops.
     while len(valid_ngram_details) < num_samples and attempts < max_attempts:
         idx = random.choice(valid_positions)
-        ngram_token_ids = token_ids[idx: idx+ngram_size]
+        ngram_token_ids = token_ids[idx : idx + ngram_size]
         ngram_str = tokenizer.decode(ngram_token_ids, clean_up_tokenization_spaces=True)
         # Only accept n-grams that contain only allowed characters.
         if ALLOWED_RE.fullmatch(ngram_str) and len(ngram_str.strip()) > ngram_size * 3:
@@ -101,10 +105,9 @@ def process_document(doc, tokenizer, ngram_size, num_samples, index="v4_rpj_llama_s4"):
     sample_count = len(valid_ngram_details)
     return doc_id, match_count, sample_count, valid_ngram_details
 
 
 def main():
-    parser = argparse.ArgumentParser(
-        description="Infini-gram n-gram matching script with Llama2 tokenization."
-    )
+    parser = argparse.ArgumentParser(description="Infini-gram n-gram matching script with Llama2 tokenization.")
     parser.add_argument("N", type=int, help="Number of random .jsonl files to process")
     parser.add_argument("s3_path", type=str, help="S3 path to a prefix containing .jsonl files (e.g., s3://my-bucket/my-prefix/)")
     parser.add_argument("--index", type=str, default="v4_dolma-v1_7_llama", help="Infini-gram index to use (default: v4_rpj_llama_s4)")
@@ -149,10 +152,8 @@ def main():
             except Exception as e:
                 print(f"Error parsing JSON in {key}: {e}")
                 continue
 
-            doc_id, match_count, sample_count, details = process_document(
-                doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index
-            )
+            doc_id, match_count, sample_count, details = process_document(doc, tokenizer, args.ngram_size, args.num_ngrams, index=args.index)
 
             # Print per-document n-gram summary
             print(f"\nDocument ID: {doc_id}")
             for flag, ngram in details:
@@ -160,12 +161,13 @@ def main():
                 print(f"{flag:4} {repr(ngram)}")
             percentage = (match_count / sample_count * 100) if sample_count else 0
             print(f"Matched n-grams: {match_count}/{sample_count} ({percentage:.2f}%)")
+
             total_matches += match_count
             total_ngrams_sampled += sample_count
 
     overall_percentage = (total_matches / total_ngrams_sampled * 100) if total_ngrams_sampled else 0
     print(f"\nTotal matched n-grams: {total_matches}/{total_ngrams_sampled} ({overall_percentage:.2f}%)")
 
 
 if __name__ == "__main__":
     main()
@@ -16,9 +16,7 @@ from tqdm import tqdm
 from transformers import AutoProcessor
 
 from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
-from olmocr.train.dataloader import (
-    build_finetuning_dataset,
-)
+from olmocr.train.dataloader import build_finetuning_dataset
 from olmocr.train.dataprep import (
     batch_prepare_data_for_molmo_training,
     build_finetuning_prompt,
@@ -223,7 +221,6 @@ class TestMolmoDataPrep(unittest.TestCase):
             patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
             patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
         ):
-
             # Set return values for the mocked functions
             mock_get_anchor_text.return_value = "This is the anchor text."
             # Create a red square image and encode it in base64
@@ -305,7 +302,6 @@ class TestMolmoDataPrep(unittest.TestCase):
             patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
             patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
         ):
-
             mock_get_anchor_text.return_value = "This is the anchor text."
             img = Image.new("RGB", (100, 100), color="red")
             buffered = BytesIO()