From 4a1762d4551ef034b21ed9bb46d6b904fbb47cfe Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 29 Jan 2025 15:25:10 -0800
Subject: [PATCH] isort

---
 olmocr/check.py                              |  6 +-
 olmocr/data/buildsilver.py                   | 24 +++---
 olmocr/data/buildsilverdatasummary.py        | 10 ++-
 olmocr/data/buildtestset.py                  | 15 ++--
 olmocr/data/convertsilver_birr.py            | 19 ++---
 olmocr/data/convertsilver_openai.py          | 16 ++--
 olmocr/data/renderpdf.py                     |  5 +-
 olmocr/data/runopenaibatch.py                |  9 ++-
 olmocr/eval/buildelo.py                      | 13 ++--
 olmocr/eval/evalhtml.py                      |  8 +-
 olmocr/eval/runeval.py                       | 29 ++++---
 olmocr/eval/scoreelo.py                      |  8 +-
 olmocr/filter/coherency.py                   |  2 +-
 olmocr/filter/filter.py                      |  6 +-
 olmocr/metrics.py                            |  5 +-
 olmocr/pipeline.py                           | 80 +++++++++++---------
 olmocr/prompts/_adv_anchor.py                | 20 ++---
 olmocr/prompts/anchor.py                     | 16 ++--
 olmocr/prompts/prompts.py                    |  1 +
 olmocr/repeatdetect.py                       |  3 +-
 olmocr/s3_utils.py                           | 33 ++++----
 olmocr/train/buildparquetdataset.py          |  6 +-
 olmocr/train/core/paths.py                   |  5 +-
 olmocr/train/dataloader.py                   | 34 +++++----
 olmocr/train/dataprep.py                     | 15 ++--
 olmocr/train/fixqwen2vlcheckpoint.py         | 12 +--
 olmocr/train/inference.py                    | 26 +++----
 olmocr/train/loaddataset.py                  | 14 ++--
 olmocr/train/molmo/config_molmo.py           |  2 +-
 olmocr/train/molmo/image_processing_molmo.py |  8 +-
 olmocr/train/molmo/modeling_molmo.py         | 24 ++++--
 olmocr/train/molmo/preprocessing_molmo.py    | 11 +--
 olmocr/train/train.py                        | 28 +++----
 olmocr/train/utils.py                        | 13 ++--
 olmocr/viewer/dolmaviewer.py                 | 20 ++---
 olmocr/work_queue.py                         | 18 ++---
 scripts/benchmark_throughput.py              | 19 +++--
 scripts/movedolmadocs_to_md.py               |  8 +-
 tests/test_anchor.py                         | 13 ++--
 tests/test_coherency.py                      |  3 +-
 tests/test_dataloader.py                     |  7 +-
 tests/test_dataprep.py                       | 29 +++----
 tests/test_molmo.py                          | 10 ++-
 tests/test_s3_work_queue.py                  | 10 ++-
 tests/test_sglang.py                         | 39 ++++++----
 45 files changed, 389 insertions(+), 313 deletions(-)

diff --git a/olmocr/check.py b/olmocr/check.py
index 2f3c641..8720eca 100644
--- a/olmocr/check.py
+++ b/olmocr/check.py
@@ -1,7 +1,7 @@
-import sys
-import subprocess
-import logging
 import importlib.util
+import logging
+import subprocess
+import sys
 
 logger = logging.getLogger(__name__)
 
diff --git a/olmocr/data/buildsilver.py b/olmocr/data/buildsilver.py
index fd2269f..298966d 100644
--- a/olmocr/data/buildsilver.py
+++ b/olmocr/data/buildsilver.py
@@ -1,21 +1,25 @@
-import os
+import argparse
+import base64
 import glob
+import json
+import os
 import random
 import subprocess
-import base64
-import argparse
-import boto3
-import json
-from pypdf import PdfReader
-from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from typing import Generator
-from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse
 
+import boto3
+from pypdf import PdfReader
+from tqdm import tqdm
+
 from olmocr.data.renderpdf import render_pdf_to_base64png
-from olmocr.prompts import build_openai_silver_data_prompt, openai_response_format_schema
-from olmocr.prompts.anchor import get_anchor_text
 from olmocr.filter import PdfFilter
+from olmocr.prompts import (
+    build_openai_silver_data_prompt,
+    openai_response_format_schema,
+)
+from olmocr.prompts.anchor import get_anchor_text
 
 TARGET_IMAGE_DIM = 2048
 
diff --git a/olmocr/data/buildsilverdatasummary.py b/olmocr/data/buildsilverdatasummary.py
index 6150ac8..3ad8c4f 100644
--- a/olmocr/data/buildsilverdatasummary.py
+++ b/olmocr/data/buildsilverdatasummary.py
@@ -1,15 +1,17 @@
-import os
+import argparse
+import collections
 import csv
 import json
-import argparse
-import re
-import collections
+import os
 import random
+import re
 import sqlite3
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse
+
 from tqdm import tqdm
 
+
 def parse_pdf_hash(pretty_pdf_path: str) -> str:
     pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
     match = re.match(pattern, pretty_pdf_path)
diff --git a/olmocr/data/buildtestset.py b/olmocr/data/buildtestset.py
index 3cf2128..dfc3381 100644
--- a/olmocr/data/buildtestset.py
+++ b/olmocr/data/buildtestset.py
@@ -1,14 +1,15 @@
-import os
-import glob
-import random
 import argparse
-import boto3
 import base64
+import glob
+import os
+import random
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import List
+from urllib.parse import urlparse
+
+import boto3
 from pypdf import PdfReader, PdfWriter
 from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
-from urllib.parse import urlparse
-from typing import List
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.filter import PdfFilter
diff --git a/olmocr/data/convertsilver_birr.py b/olmocr/data/convertsilver_birr.py
index 83b56e7..3f119ce 100644
--- a/olmocr/data/convertsilver_birr.py
+++ b/olmocr/data/convertsilver_birr.py
@@ -1,21 +1,22 @@
 import argparse
 import json
-import re
-from pathlib import Path
-from concurrent.futures import ProcessPoolExecutor, as_completed
-import sys
 import logging
-import tempfile
 import os
+import re
+import sys
+import tempfile
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from pathlib import Path
 
-import smart_open
 import boto3
-from olmocr.prompts import build_finetuning_prompt
-from olmocr.prompts.anchor import get_anchor_text
-from olmocr.data.renderpdf import render_pdf_to_base64png
 
 # Import Plotly for plotting
 import plotly.express as px
+import smart_open
+
+from olmocr.data.renderpdf import render_pdf_to_base64png
+from olmocr.prompts import build_finetuning_prompt
+from olmocr.prompts.anchor import get_anchor_text
 
 
 def setup_logging():
diff --git a/olmocr/data/convertsilver_openai.py b/olmocr/data/convertsilver_openai.py
index 4873a39..3e16286 100644
--- a/olmocr/data/convertsilver_openai.py
+++ b/olmocr/data/convertsilver_openai.py
@@ -1,14 +1,15 @@
 import argparse
 import json
-import re
-from pathlib import Path
-from concurrent.futures import ProcessPoolExecutor, as_completed
-import sys
-import os
 import logging
+import os
+import re
+import sys
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from pathlib import Path
 
 import smart_open
 from cached_path import cached_path
+
 from olmocr.prompts import build_finetuning_prompt
 
 
@@ -73,8 +74,8 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
             # Save the pdf to a temporary cache folder
             local_pdf_path = cached_path(s3_path, quiet=True)
 
-            from olmocr.prompts.anchor import get_anchor_text
             from olmocr.data.buildsilver import build_page_query
+            from olmocr.prompts.anchor import get_anchor_text
 
             obj = build_page_query(local_pdf_path, s3_path, page)
             # raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
@@ -142,9 +143,10 @@ def list_input_files(input_dir):
     """
     if is_s3_path(input_dir):
         # Use smart_open's s3 functionality to list files
-        import boto3
         import fnmatch
+
+        import boto3
 
         # Parse bucket and prefix
         bucket_name = input_dir.split('s3://')[1].split('/')[0]
         path_and_pattern = '/'.join(input_dir.split('s3://')[1].split('/')[1:])
diff --git a/olmocr/data/renderpdf.py b/olmocr/data/renderpdf.py
index b3bbec3..a3e04d4 100644
--- a/olmocr/data/renderpdf.py
+++ b/olmocr/data/renderpdf.py
@@ -1,8 +1,9 @@
-import subprocess
 import base64
 import io
-from pypdf import PdfReader
+import subprocess
+
 from PIL import Image
+from pypdf import PdfReader
 
 
 def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
diff --git a/olmocr/data/runopenaibatch.py b/olmocr/data/runopenaibatch.py
index 579b62e..b989a66 100644
--- a/olmocr/data/runopenaibatch.py
+++ b/olmocr/data/runopenaibatch.py
@@ -1,15 +1,16 @@
 # Sends list of batch files to OpenAI for processing
 # However, it also waits and gets the files when they are done, saves its state, and
 # allows you to submit more than the 100GB of file request limits that the openaiAPI has
+import argparse
+import datetime
+import json
 import os
 import time
-import json
-import datetime
-import argparse
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from enum import Enum
+
 from openai import OpenAI
 from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
 
 # Set up OpenAI client (API key should be set in the environment)
 client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
diff --git a/olmocr/eval/buildelo.py b/olmocr/eval/buildelo.py
index 370e58e..c4d8acd 100644
--- a/olmocr/eval/buildelo.py
+++ b/olmocr/eval/buildelo.py
@@ -1,19 +1,20 @@
 import argparse
-import boto3
 import dataclasses
+import functools
 import random
 import re
 from concurrent.futures import ProcessPoolExecutor, as_completed
-import functools
-
-from tqdm import tqdm
 from itertools import combinations
-from olmocr.s3_utils import parse_s3_path, expand_s3_glob, get_s3_bytes
+
+import boto3
+from dolma_refine.evaluate.aligners import HirschbergAligner
 from dolma_refine.evaluate.metrics import DocumentEditSimilarity
 from dolma_refine.evaluate.segmenters import SpacySegmenter
-from dolma_refine.evaluate.aligners import HirschbergAligner
+from tqdm import tqdm
 
 from olmocr.eval.evalhtml import create_review_html
+from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path
+
 
 @dataclasses.dataclass
 class Comparison:
diff --git a/olmocr/eval/evalhtml.py b/olmocr/eval/evalhtml.py
index 0b5dffd..eb5090d 100644
--- a/olmocr/eval/evalhtml.py
+++ b/olmocr/eval/evalhtml.py
@@ -1,12 +1,14 @@
 import os
 import random
 import tempfile
-import boto3
 from concurrent.futures import ThreadPoolExecutor
-from jinja2 import Template
-from urllib.parse import urlparse
 from difflib import SequenceMatcher
+from urllib.parse import urlparse
+
+import boto3
+from jinja2 import Template
 from tqdm import tqdm
+
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
 session = boto3.Session(profile_name='s2')
diff --git a/olmocr/eval/runeval.py b/olmocr/eval/runeval.py
index e83a575..07931cf 100644
--- a/olmocr/eval/runeval.py
+++ b/olmocr/eval/runeval.py
@@ -3,29 +3,28 @@
 # You might need to pip install git+https://github.com/allenai/refine.git@soldni/eval-m
 # in order to use some of the existing aligner scoring that was developed as part
 # of the refiner pipeline
-import boto3
-import os
-import json
-import hashlib
-import random
-import zstandard
-import sys
 import argparse
-
+import hashlib
+import json
+import logging
+import os
+import random
+import sys
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
-from typing import Optional, Tuple, Dict
-from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from pathlib import Path
-from smart_open import smart_open, register_compressor
+from typing import Dict, Optional, Tuple
+
+import boto3
+import zstandard
+from dolma_refine.evaluate.aligners import HirschbergAligner
 from dolma_refine.evaluate.metrics import DocumentEditSimilarity
 from dolma_refine.evaluate.segmenters import SpacySegmenter
-from dolma_refine.evaluate.aligners import HirschbergAligner
+from smart_open import register_compressor, smart_open
+from tqdm import tqdm
 
 from .evalhtml import create_review_html
 
-import logging
-
 logging.getLogger("pypdf").setLevel(logging.ERROR)
 
diff --git a/olmocr/eval/scoreelo.py b/olmocr/eval/scoreelo.py
index 94ae2c0..1006a72 100644
--- a/olmocr/eval/scoreelo.py
+++ b/olmocr/eval/scoreelo.py
@@ -1,8 +1,10 @@
-import requests
-import re
-from urllib.parse import urlsplit, urlunsplit, parse_qs, urlencode
 import csv
+import re
 from collections import defaultdict
+from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
+
+import requests
+
 
 def fetch_review_page_html(url):
     """
diff --git a/olmocr/filter/coherency.py b/olmocr/filter/coherency.py
index eb638e2..5139517 100644
--- a/olmocr/filter/coherency.py
+++ b/olmocr/filter/coherency.py
@@ -7,7 +7,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 @lru_cache()
 def load_coherency_model(model_name: str = "HuggingFaceTB/SmolLM-135M"):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
+    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
     model.eval() # Set the model to evaluation mode
 
     return tokenizer, model
diff --git a/olmocr/filter/filter.py b/olmocr/filter/filter.py
index 35f05e9..b15c545 100644
--- a/olmocr/filter/filter.py
+++ b/olmocr/filter/filter.py
@@ -124,11 +124,13 @@ class PdfFilter:
 
 if __name__ == "__main__":
     import tempfile
+    from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
+
     import boto3
-    from olmocr.s3_utils import parse_s3_path
-    from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED
     from tqdm import tqdm
 
+    from olmocr.s3_utils import parse_s3_path
+
     # Quiet logs from pypdf
     logging.getLogger("pypdf").setLevel(logging.ERROR)
 
diff --git a/olmocr/metrics.py b/olmocr/metrics.py
index c5d2969..f4cfafe 100644
--- a/olmocr/metrics.py
+++ b/olmocr/metrics.py
@@ -1,9 +1,10 @@
-import time
 import asyncio
-from collections import deque, defaultdict
+import time
+from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Dict
 
+
 class MetricsKeeper:
     def __init__(self, window=60*5):
         """
diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py
index 7b6127e..ca4e734 100644
--- a/olmocr/pipeline.py
+++ b/olmocr/pipeline.py
@@ -1,47 +1,59 @@
-import logging
 import argparse
-import boto3
-import signal
-import os
-import sys
-import time
-import subprocess
+import asyncio
+import atexit
+import base64
+import datetime
+import glob
 import hashlib
 import json
-import base64
-import atexit
-import asyncio
-import httpx
-import datetime
-import tempfile
-import random
-import shutil
-import re
-import glob
-import torch
+import logging
 import multiprocessing
-
-from tqdm import tqdm
-from urllib.parse import urlparse
-from botocore.exceptions import ClientError
-from io import BytesIO
-from PIL import Image
-from pypdf import PdfReader
-from functools import partial, cache
-from dataclasses import dataclass
-from typing import Optional, Tuple, List, Dict, Set
+import os
+import random
+import re
+import shutil
+import signal
+import subprocess
+import sys
+import tempfile
+import time
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from concurrent.futures.process import BrokenProcessPool
+from dataclasses import dataclass
+from functools import cache, partial
+from io import BytesIO
+from typing import Dict, List, Optional, Set, Tuple
+from urllib.parse import urlparse
 
-from olmocr.work_queue import WorkQueue, S3WorkQueue, LocalWorkQueue
-from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
+import boto3
+import httpx
+import torch
+from botocore.exceptions import ClientError
+from PIL import Image
+from pypdf import PdfReader
+from tqdm import tqdm
+
+from olmocr.check import (
+    check_poppler_version,
+    check_sglang_version,
+    check_torch_gpu_available,
+)
 from olmocr.data.renderpdf import render_pdf_to_base64png
-from olmocr.filter.filter import PdfFilter, Language
-from olmocr.prompts import build_finetuning_prompt, PageResponse
-from olmocr.prompts.anchor import get_anchor_text
-from olmocr.check import check_poppler_version, check_sglang_version, check_torch_gpu_available
+from olmocr.filter.filter import Language, PdfFilter
 from olmocr.metrics import MetricsKeeper, WorkerTracker
+from olmocr.prompts import PageResponse, build_finetuning_prompt
+from olmocr.prompts.anchor import get_anchor_text
+from olmocr.s3_utils import (
+    download_directory,
+    download_zstd_csv,
+    expand_s3_glob,
+    get_s3_bytes,
+    get_s3_bytes_with_backoff,
+    parse_s3_path,
+    upload_zstd_csv,
+)
 from olmocr.version import VERSION
+from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
 
 # Initialize logger
 logger = logging.getLogger(__name__)
diff --git a/olmocr/prompts/_adv_anchor.py b/olmocr/prompts/_adv_anchor.py
index 777ab01..aa1f86c 100644
--- a/olmocr/prompts/_adv_anchor.py
+++ b/olmocr/prompts/_adv_anchor.py
@@ -1,11 +1,4 @@
 import math
-from pypdf.generic import (
-    DictionaryObject,
-    FloatObject,
-    TextStringObject,
-    NumberObject,
-    NameObject
-)
 from typing import (
     Any,
     Callable,
@@ -21,12 +14,21 @@ from typing import (
     cast,
     overload,
 )
+
+from pypdf._cmap import build_char_map, unknown_char_map
 from pypdf.constants import AnnotationDictionaryAttributes as ADA
 from pypdf.constants import ImageAttributes as IA
 from pypdf.constants import PageAttributes as PG
 from pypdf.constants import Resources as RES
-from pypdf.generic import ContentStream, encode_pdfdocencoding
-from pypdf._cmap import build_char_map, unknown_char_map
+from pypdf.generic import (
+    ContentStream,
+    DictionaryObject,
+    FloatObject,
+    NameObject,
+    NumberObject,
+    TextStringObject,
+    encode_pdfdocencoding,
+)
 
 CUSTOM_RTL_MIN: int = -1
 CUSTOM_RTL_MAX: int = -1
diff --git a/olmocr/prompts/anchor.py b/olmocr/prompts/anchor.py
index dfb0935..7405ef8 100644
--- a/olmocr/prompts/anchor.py
+++ b/olmocr/prompts/anchor.py
@@ -7,22 +7,22 @@
 # pymupdf
 # pypdf
 
+import random
+import re
+
 # coherency score best of these three
 import subprocess
-import re
-import random
-import ftfy
 from dataclasses import dataclass
-from typing import Literal, List
 from functools import lru_cache
+from typing import List, Literal
 
-import pypdfium2 as pdfium
+import ftfy
 import pymupdf
-
-from olmocr.filter.coherency import get_document_coherency
-
+import pypdfium2 as pdfium
 from pypdf import PdfReader
 from pypdf.generic import RectangleObject
+
+from olmocr.filter.coherency import get_document_coherency
 from olmocr.prompts._adv_anchor import mult
 
diff --git a/olmocr/prompts/prompts.py b/olmocr/prompts/prompts.py
index c4b5dce..011d837 100644
--- a/olmocr/prompts/prompts.py
+++ b/olmocr/prompts/prompts.py
@@ -2,6 +2,7 @@ import re
 from dataclasses import dataclass
 from typing import Optional
 
+
 # This is the prompt we use for getting chat gpt 4o to convert documents into our silver training data
 def build_openai_silver_data_prompt(base_text: str) -> str:
     return (
diff --git a/olmocr/repeatdetect.py b/olmocr/repeatdetect.py
index 06e1d7d..af8a6b5 100644
--- a/olmocr/repeatdetect.py
+++ b/olmocr/repeatdetect.py
@@ -1,7 +1,8 @@
-import unittest
 import random
 import string
 import time
+import unittest
+
 
 class RepeatDetector:
     def __init__(self, max_ngram_size: int = 10):
diff --git a/olmocr/s3_utils.py b/olmocr/s3_utils.py
index 2c5420a..c081ed2 100644
--- a/olmocr/s3_utils.py
+++ b/olmocr/s3_utils.py
@@ -1,26 +1,25 @@
-import os
-import glob
-import posixpath
-import logging
-import tempfile
 import base64
-import boto3
-import time
-import requests
 import concurrent.futures
+import glob
 import hashlib
-
-from urllib.parse import urlparse
+import logging
+import os
+import posixpath
+import tempfile
+import time
+from io import BytesIO, TextIOWrapper
 from pathlib import Path
+from typing import List, Optional
+from urllib.parse import urlparse
+
+import boto3
+import requests
+import zstandard as zstd
+from boto3.s3.transfer import TransferConfig
+from botocore.config import Config
+from botocore.exceptions import ClientError, NoCredentialsError
 from google.auth import compute_engine
 from google.cloud import storage
-from botocore.config import Config
-from botocore.exceptions import NoCredentialsError, ClientError
-from boto3.s3.transfer import TransferConfig
-from typing import Optional, List
-from urllib.parse import urlparse
-import zstandard as zstd
-from io import BytesIO, TextIOWrapper
 from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
diff --git a/olmocr/train/buildparquetdataset.py b/olmocr/train/buildparquetdataset.py
index ecf10cb..44f8e43 100644
--- a/olmocr/train/buildparquetdataset.py
+++ b/olmocr/train/buildparquetdataset.py
@@ -1,10 +1,12 @@
 import argparse
 import logging
-from functools import partial
 import os
+from functools import partial
+
 import boto3
-from datasets import Dataset
 from botocore.exceptions import NoCredentialsError, PartialCredentialsError
+from datasets import Dataset
+
 from olmocr.train.dataloader import build_batch_query_response_vision_dataset
 
diff --git a/olmocr/train/core/paths.py b/olmocr/train/core/paths.py
index ea6fb29..8201de7 100644
--- a/olmocr/train/core/paths.py
+++ b/olmocr/train/core/paths.py
@@ -1,14 +1,15 @@
 import glob
 import os
 import re
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from hashlib import sha256
 from itertools import chain
 from pathlib import Path
+from shutil import copyfileobj
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
 from urllib.parse import urlparse
-from concurrent.futures import ThreadPoolExecutor
-from shutil import copyfileobj
+
 import platformdirs
 import smart_open
 from fsspec import AbstractFileSystem, get_filesystem_class
diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py
index 4f6857d..2d475a9 100644
--- a/olmocr/train/dataloader.py
+++ b/olmocr/train/dataloader.py
@@ -1,24 +1,32 @@
-import json
-import logging
-import tempfile
-import re
-import os
 import base64
 import glob
-import pypdf, pypdf.errors
-
+import json
+import logging
+import os
+import re
+import tempfile
 from functools import partial
-from typing import Any, Dict, Optional
 from logging import Logger
-from filelock import FileLock
+from typing import Any, Dict, Optional
 
 import boto3
-from datasets import Dataset, Features, Value, load_dataset, concatenate_datasets, DatasetDict
-from .core.config import DataConfig, SourceConfig
+import pypdf
+import pypdf.errors
+from datasets import (
+    Dataset,
+    DatasetDict,
+    Features,
+    Value,
+    concatenate_datasets,
+    load_dataset,
+)
+from filelock import FileLock
 
-from olmocr.prompts.anchor import get_anchor_text
-from olmocr.s3_utils import parse_custom_id, get_s3_bytes, parse_s3_path
 from olmocr.data.renderpdf import get_pdf_media_box_width_height
+from olmocr.prompts.anchor import get_anchor_text
+from olmocr.s3_utils import get_s3_bytes, parse_custom_id, parse_s3_path
+
+from .core.config import DataConfig, SourceConfig
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
diff --git a/olmocr/train/dataprep.py b/olmocr/train/dataprep.py
index 46b181c..14faf31 100644
--- a/olmocr/train/dataprep.py
+++ b/olmocr/train/dataprep.py
@@ -1,14 +1,15 @@
-import numpy as np
-from io import BytesIO
-from PIL import Image
-from typing import Union
 import base64
 import random
-import torch # Make sure to import torch as it's used in the DataCollator
+from io import BytesIO
+from typing import Union
+
+import numpy as np
+import torch # Make sure to import torch as it's used in the DataCollator
+from PIL import Image
 
-from olmocr.prompts.anchor import get_anchor_text
-from olmocr.prompts import build_finetuning_prompt
 from olmocr.data.renderpdf import render_pdf_to_base64png
+from olmocr.prompts import build_finetuning_prompt
+from olmocr.prompts.anchor import get_anchor_text
 
 
 def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
diff --git a/olmocr/train/fixqwen2vlcheckpoint.py b/olmocr/train/fixqwen2vlcheckpoint.py
index 584e793..d8e5a23 100644
--- a/olmocr/train/fixqwen2vlcheckpoint.py
+++ b/olmocr/train/fixqwen2vlcheckpoint.py
@@ -1,14 +1,14 @@
 import argparse
-import os
-import json
-import torch
-import boto3
-import tempfile
 import concurrent.futures
+import json
+import os
+import tempfile
+
+import boto3
+import torch
 from smart_open import smart_open
-
 from transformers import Qwen2VLForConditionalGeneration
+
 from olmocr.s3_utils import parse_s3_path
 
 s3_client = boto3.client('s3')
diff --git a/olmocr/train/inference.py b/olmocr/train/inference.py
index 798760d..a797045 100644
--- a/olmocr/train/inference.py
+++ b/olmocr/train/inference.py
@@ -1,37 +1,37 @@
-import os
-import json
 import base64
+import json
 import logging
+import os
 import time
-from io import BytesIO
-from PIL import Image
 from functools import partial
+from io import BytesIO
 from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
-from tqdm import tqdm
 
 import accelerate
 import torch
 import torch.distributed
-
+from PIL import Image
+from tqdm import tqdm
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
+    AutoProcessor,
+    Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
     Trainer,
    TrainerCallback,
     TrainingArguments,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    AutoProcessor,
-    AutoConfig,
 )
-
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.anchor import get_anchor_text
-from olmocr.prompts.prompts import build_finetuning_prompt, build_openai_silver_data_prompt
-
+from olmocr.prompts.prompts import (
+    build_finetuning_prompt,
+    build_openai_silver_data_prompt,
+)
 
 
 @torch.no_grad()
diff --git a/olmocr/train/loaddataset.py b/olmocr/train/loaddataset.py
index c3c296f..6cd75e2 100644
--- a/olmocr/train/loaddataset.py
+++ b/olmocr/train/loaddataset.py
@@ -1,16 +1,12 @@
-from transformers import (
-    AutoProcessor,
-    DataCollatorForSeq2Seq
-)
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoProcessor, DataCollatorForSeq2Seq
 
 from olmocr.train.core.cli import make_cli
 from olmocr.train.core.config import TrainConfig
-from tqdm import tqdm
-from .utils import (
-    make_dataset, TruncatingCollator
-)
-from torch.utils.data import DataLoader
 
+from .utils import TruncatingCollator, make_dataset
+
 
 def main():
diff --git a/olmocr/train/molmo/config_molmo.py b/olmocr/train/molmo/config_molmo.py
index 2322810..3221353 100644
--- a/olmocr/train/molmo/config_molmo.py
+++ b/olmocr/train/molmo/config_molmo.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from transformers import PretrainedConfig, AutoTokenizer
+from transformers import AutoTokenizer, PretrainedConfig
 
 
 class MolmoConfig(PretrainedConfig):
diff --git a/olmocr/train/molmo/image_processing_molmo.py b/olmocr/train/molmo/image_processing_molmo.py
index ef787ec..0043bf7 100644
--- a/olmocr/train/molmo/image_processing_molmo.py
+++ b/olmocr/train/molmo/image_processing_molmo.py
@@ -1,13 +1,13 @@
 """Image processor class for Molmo"""
-from typing import List, Optional, Union, Mapping
+from typing import List, Mapping, Optional, Union
 
-import numpy as np
 import einops
+import numpy as np
 import torch
 import torchvision.transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import convert_image_dtype
-
+from transformers.image_processing_utils import BaseImageProcessor
 from transformers.image_utils import (
     OPENAI_CLIP_MEAN,
     OPENAI_CLIP_STD,
@@ -15,10 +15,8 @@ from transformers.image_utils import (
     is_valid_image,
 )
 from transformers.processing_utils import ImagesKwargs
-from transformers.image_processing_utils import BaseImageProcessor
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
diff --git a/olmocr/train/molmo/modeling_molmo.py b/olmocr/train/molmo/modeling_molmo.py
index 606cbf8..31f7952 100644
--- a/olmocr/train/molmo/modeling_molmo.py
+++ b/olmocr/train/molmo/modeling_molmo.py
@@ -1,21 +1,31 @@
 import logging
 import math
 from copy import deepcopy
-from dataclasses import fields, dataclass, replace
+from dataclasses import dataclass, fields, replace
 from enum import Enum
-from typing import List, Optional, Tuple, Union, Dict, Any, Sequence, Callable, cast, MutableMapping
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    MutableMapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 import torch
-from einops import einsum, einops
-from transformers import PreTrainedModel, GenerationConfig
+from einops import einops, einsum
+from torch import nn
+from torch.nn import functional as F
+from transformers import GenerationConfig, PreTrainedModel
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 from transformers.models.auto import AutoModelForCausalLM
-from torch import nn
 
 from .config_molmo import MolmoConfig
-from torch.nn import functional as F
-
 
 log = logging.getLogger(__name__)
 
diff --git a/olmocr/train/molmo/preprocessing_molmo.py b/olmocr/train/molmo/preprocessing_molmo.py
index 3598399..acc2fed 100644
--- a/olmocr/train/molmo/preprocessing_molmo.py
+++ b/olmocr/train/molmo/preprocessing_molmo.py
@@ -15,20 +15,17 @@ except ImportError:
 
 import numpy as np
 import torch
-
+from transformers import AutoTokenizer
 from transformers.image_utils import ImageInput
 from transformers.processing_utils import (
-    TextKwargs,
     ProcessingKwargs,
     ProcessorMixin,
+    TextKwargs,
 )
-
-from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 from transformers.utils import logging
-from transformers import AutoTokenizer
 
-from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
-
+from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
 
 logger = logging.get_logger(__name__)
 
diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index cdac627..b76ece5 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -1,40 +1,39 @@
-import os
-import json
 import base64
+import json
 import logging
-import time
+import os
 import random
-from io import BytesIO
-from PIL import Image
+import time
 from functools import partial
+from io import BytesIO
 from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
-from tqdm import tqdm
 
 import accelerate
 import torch
 import torch.distributed
+import wandb
 from datasets import DatasetDict, concatenate_datasets
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model # pyright: ignore
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
+    AutoProcessor,
+    Qwen2VLForConditionalGeneration,
     Trainer,
     TrainerCallback,
     TrainingArguments,
-    Qwen2VLForConditionalGeneration,
-    AutoProcessor,
-    AutoConfig,
 )
 from transformers.integrations import WandbCallback
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
-from torch.utils.data import DataLoader
-
-import wandb
 
 from olmocr.train.core.cli import make_cli, save_config, to_native_types
 from olmocr.train.core.config import TrainConfig
@@ -44,13 +43,14 @@ from olmocr.train.core.state import BeakerState
 
 from .utils import (
     RunName,
+    TruncatingCollator,
     get_local_dir,
     log_trainable_parameters,
-    setup_environment,
     make_dataset,
-    TruncatingCollator
+    setup_environment,
 )
 
+
 class CheckpointUploadCallback(TrainerCallback):
     def __init__(self, save_path: str, logger: Optional[Logger] = None):
         self.save_path = save_path
diff --git a/olmocr/train/utils.py b/olmocr/train/utils.py
index 0ab9a2c..7128457 100644
--- a/olmocr/train/utils.py
+++ b/olmocr/train/utils.py
@@ -5,31 +5,34 @@ import random
 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import datetime
+from functools import partial
 from hashlib import sha1
 from logging import Logger
 from tempfile import TemporaryDirectory
 from typing import Dict, Generator, List, Optional, TypeVar
-from functools import partial
-
 
 import torch
 import torch.nn.functional as F
-from transformers import AutoProcessor
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
 from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
+from transformers import AutoProcessor
 
 from .core.cli import to_native_types
-from .core.config import AwsConfig, TrainConfig, WandbConfig, DataConfig, SourceConfig
+from .core.config import AwsConfig, DataConfig, SourceConfig, TrainConfig, WandbConfig
 from .core.loggers import get_logger
 from .core.paths import copy_dir, is_local
 from .core.state import BeakerState
+
 # from .tokenization import ModelTokenizer
 
 T = TypeVar("T")
 
 from olmocr.train.dataloader import build_finetuning_dataset, list_dataset_files
-from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training, batch_prepare_data_for_molmo_training
+from olmocr.train.dataprep import (
+    batch_prepare_data_for_molmo_training,
+    batch_prepare_data_for_qwen2_training,
+)
 
 
 def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
diff --git a/olmocr/viewer/dolmaviewer.py b/olmocr/viewer/dolmaviewer.py
index 4bef355..4fe0acc 100644
--- a/olmocr/viewer/dolmaviewer.py
+++ b/olmocr/viewer/dolmaviewer.py
@@ -1,19 +1,21 @@
-import os
-import json
-import html
 import argparse
-import boto3
-import tempfile
 import glob
+import html
+import json
+import os
+import tempfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import boto3
+import markdown2
+import smart_open
 from botocore.exceptions import NoCredentialsError, PartialCredentialsError
 from jinja2 import Template
-import smart_open
 from tqdm import tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
-import markdown2
 
-from olmocr.s3_utils import get_s3_bytes, parse_s3_path
 from olmocr.data.renderpdf import render_pdf_to_base64webp
+from olmocr.s3_utils import get_s3_bytes, parse_s3_path
+
 
 def read_jsonl(paths):
     """
diff --git a/olmocr/work_queue.py b/olmocr/work_queue.py
index 8fa487c..4b4b742 100644
--- a/olmocr/work_queue.py
+++ b/olmocr/work_queue.py
@@ -1,15 +1,14 @@
+import abc
+import asyncio
+import datetime
+import hashlib
+import logging
 import os
 import random
-import logging
-import hashlib
 import tempfile
-import datetime
-import asyncio
-import abc
-from typing import Optional, List, Dict, Set
 from dataclasses import dataclass
-
 from functools import partial
+from typing import Dict, List, Optional, Set
 
 logger = logging.getLogger(__name__)
 
@@ -373,12 +372,13 @@ class LocalWorkQueue(WorkQueue):
 # --------------------------------------------------------------------------------------
 from olmocr.s3_utils import (
-    expand_s3_glob,
     download_zstd_csv,
+    expand_s3_glob,
+    parse_s3_path,
     upload_zstd_csv,
-    parse_s3_path
 )
 
+
 class S3WorkQueue(WorkQueue):
     """
     Manages a work queue stored in S3 that coordinates work across multiple workers.
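
A note on the work_queue.py hunk above: isort alphabetizes the names inside the parenthesized "from olmocr.s3_utils import (...)" statement but leaves the statement where it sits, below LocalWorkQueue. By default isort only sorts within each contiguous block of imports; it does not hoist imports that were deliberately placed mid-module (the float_to_top option would change that). A minimal sketch of this behavior using isort 5's Python API follows; the snippet is illustrative only, assumes default settings, and is not part of the patch:

    import isort

    # A deferred, mid-module import, e.g. one kept below other code on purpose.
    snippet = (
        "logger = None\n"
        "\n"
        "from olmocr.s3_utils import upload_zstd_csv, download_zstd_csv\n"
    )

    # isort re-sorts the imported names in place but does not move the
    # statement above the preceding code.
    print(isort.code(snippet))
    # logger = None
    #
    # from olmocr.s3_utils import download_zstd_csv, upload_zstd_csv

That matches the hunk above, where only the order of download_zstd_csv, expand_s3_glob, parse_s3_path, and upload_zstd_csv changes, not the position of the import.
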
diff --git a/scripts/benchmark_throughput.py b/scripts/benchmark_throughput.py
index 7ebb052..b5384bb 100644
--- a/scripts/benchmark_throughput.py
+++ b/scripts/benchmark_throughput.py
@@ -1,24 +1,27 @@
 """Benchmark offline inference throughput."""
 import argparse
+import base64
 import json
 import random
 import time
-import base64
-
-from typing import List, Optional, Tuple
-from PIL import Image
 from io import BytesIO
+from typing import List, Optional, Tuple
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
-from transformers import (AutoModelForCausalLM, AutoTokenizer,
-                          PreTrainedTokenizerBase, AutoProcessor)
-
+from transformers import (
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    PreTrainedTokenizerBase,
+)
 from vllm import TokensPrompt
 from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
-    build_async_engine_client_from_engine_args)
+    build_async_engine_client_from_engine_args,
+)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
diff --git a/scripts/movedolmadocs_to_md.py b/scripts/movedolmadocs_to_md.py
index e41ca9d..8044490 100644
--- a/scripts/movedolmadocs_to_md.py
+++ b/scripts/movedolmadocs_to_md.py
@@ -1,11 +1,13 @@
 #!/usr/bin/env python3
-import os
+import argparse
 import io
 import json
-import boto3
-import argparse
+import os
 from urllib.parse import urlparse
 
+import boto3
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Read JSONL files from an S3 prefix, extract text, and write to local .md files.")
     parser.add_argument("--s3-prefix",
diff --git a/tests/test_anchor.py b/tests/test_anchor.py
index 525bbf9..4abe8cf 100644
--- a/tests/test_anchor.py
+++ b/tests/test_anchor.py
@@ -1,13 +1,14 @@
-import unittest
-import os
-import json
-import io
 import glob
+import io
+import json
+import os
+import unittest
 
 from pypdf import PdfReader
 
-from olmocr.prompts.anchor import _pdf_report, _linearize_pdf_report, get_anchor_text
 from olmocr.data.renderpdf import get_pdf_media_box_width_height
+from olmocr.prompts.anchor import _linearize_pdf_report, _pdf_report, get_anchor_text
+
 
 class AnchorTest(unittest.TestCase):
     def testExtractText(self):
@@ -61,8 +62,8 @@ class AnchorTest(unittest.TestCase):
         })
 
         import pyarrow as pa
-        import pyarrow.json as paj
         import pyarrow.compute as pc
+        import pyarrow.json as paj
 
         buffer = io.BytesIO(jsondata.encode('utf-8'))
         paj.read_json(buffer, read_options=paj.ReadOptions(use_threads=False, block_size=len(jsondata)))
diff --git a/tests/test_coherency.py b/tests/test_coherency.py
index d411d89..19a7444 100644
--- a/tests/test_coherency.py
+++ b/tests/test_coherency.py
@@ -4,11 +4,10 @@
 import os
 import time
 import unittest
-
 from olmocr.filter.coherency import get_document_coherency
-
 from olmocr.prompts.anchor import get_anchor_text
 
+
 class TestCoherencyScores(unittest.TestCase):
     def testBadOcr1(self):
         good_text = get_anchor_text(
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 21396f3..49a4092 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -1,17 +1,16 @@
 import unittest
-from torch.utils.data import DataLoader
-from tqdm import tqdm
 from functools import partial
 
+from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoProcessor
 
 from olmocr.train.dataloader import (
     build_finetuning_dataset,
     extract_openai_batch_response,
+    list_dataset_files,
     load_jsonl_into_ds,
-    list_dataset_files
 )
-
 from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training
 
diff --git a/tests/test_dataprep.py b/tests/test_dataprep.py
index 8f0d01a..73bb10a 100644
--- a/tests/test_dataprep.py
+++ b/tests/test_dataprep.py
@@ -1,28 +1,31 @@
-import unittest
-import random
-import requests
 import base64
-import torch
 import os
+import random
 import re
+import unittest
 from io import BytesIO
-from PIL import Image
-from transformers import AutoProcessor
 from unittest.mock import patch
 
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import AutoProcessor
+
+from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
 from olmocr.train.dataloader import (
     build_finetuning_dataset,
 )
-
 from olmocr.train.dataprep import (
-    prepare_data_for_qwen2_training, build_finetuning_prompt,
-    prepare_data_for_molmo_training, batch_prepare_data_for_molmo_training
+    batch_prepare_data_for_molmo_training,
+    build_finetuning_prompt,
+    prepare_data_for_molmo_training,
+    prepare_data_for_qwen2_training,
 )
-import numpy as np
-from tqdm import tqdm
-from torch.utils.data import DataLoader
 from olmocr.train.utils import make_dataset
-from olmocr.train.core.config import TrainConfig, DataConfig, SourceConfig
+
 
 class TestDataprep(unittest.TestCase):
     def testFullDataloader(self):
diff --git a/tests/test_molmo.py b/tests/test_molmo.py
index 947a289..df61993 100644
--- a/tests/test_molmo.py
+++ b/tests/test_molmo.py
@@ -1,8 +1,14 @@
 import unittest
-from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, GenerationConfig
-from PIL import Image
 
 import requests
+from PIL import Image
+from transformers import (
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+    GenerationConfig,
+)
+
 
 class MolmoProcessorTest(unittest.TestCase):
     def test_molmo_demo(self):
diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py
index 6b602a3..6ca3af8 100644
--- a/tests/test_s3_work_queue.py
+++ b/tests/test_s3_work_queue.py
@@ -1,14 +1,16 @@
-import unittest
 import asyncio
 import datetime
-from unittest.mock import Mock, patch, call
-from botocore.exceptions import ClientError
 import hashlib
-from typing import List, Dict
+import unittest
+from typing import Dict, List
+from unittest.mock import Mock, call, patch
+
+from botocore.exceptions import ClientError
 
 # Import the classes we're testing
 from olmocr.work_queue import S3WorkQueue, WorkItem
 
+
 class TestS3WorkQueue(unittest.TestCase):
     def setUp(self):
         """Set up test fixtures before each test method."""
diff --git a/tests/test_sglang.py b/tests/test_sglang.py
index 830b0b6..8ee6f04 100644
--- a/tests/test_sglang.py
+++ b/tests/test_sglang.py
@@ -4,24 +4,33 @@
 # Compare that the temperature 0 sampled result is the same
 
 import asyncio
-import unittest
-from unittest.mock import patch, AsyncMock
-import os
-import json
-import tempfile
-import math
 import base64
-import torch
-import numpy as np
+import json
+import math
+import os
+import tempfile
+import unittest
 from io import BytesIO
+from pathlib import Path
+from unittest.mock import AsyncMock, patch
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from httpx import AsyncClient
 from PIL import Image
 from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
-from pathlib import Path
-from olmocr.pipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
-from olmocr.prompts import PageResponse
-from httpx import AsyncClient
-import torch.nn.functional as F
+from olmocr.pipeline import (
+    SGLANG_SERVER_PORT,
+    build_page_query,
+    download_directory,
+    get_anchor_text,
+    render_pdf_to_base64png,
+    sglang_server_ready,
+    sglang_server_task,
+)
+from olmocr.prompts import PageResponse
 
 MODEL_FINETUNED_PATH = "s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/jakep/Qwen_Qwen2-VL-7B-Instruct-e4ecf8-01JAH8GMWHTJ376S2N7ETXRXH4/checkpoint-9500/bf16/"
 
@@ -320,12 +329,12 @@ class RawSGLangTest(unittest.IsolatedAsyncioTestCase):
         print("HF", hf_output, hf_output.shape)
 
         from sglang.srt.configs.model_config import ModelConfig
+        from sglang.srt.hf_transformers_utils import get_tokenizer
         from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
         from sglang.srt.model_executor.forward_batch_info import ForwardBatch
         from sglang.srt.model_executor.model_runner import ModelRunner
         from sglang.srt.sampling.sampling_params import SamplingParams
-        from sglang.srt.hf_transformers_utils import get_tokenizer
-        from sglang.srt.server_args import ServerArgs, PortArgs
+        from sglang.srt.server_args import PortArgs, ServerArgs
 
         model_config = ModelConfig(
             self.model_cache_dir,
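
For reference, the ordering convention enforced across every hunk in this patch (standard library first, then third-party, then first-party, each section alphabetized, with plain "import x" lines ahead of "from x import y" lines) can be reproduced with isort's Python API. A minimal sketch, assuming isort 5 and that the project's configuration, which is not shown in this patch, declares olmocr as its first-party package:

    import isort

    messy = (
        "from olmocr.filter import PdfFilter\n"
        "import boto3\n"
        "import os\n"
    )

    # Sorted output: stdlib, then third-party, then first-party sections,
    # separated by blank lines.
    tidy = isort.code(messy, known_first_party=["olmocr"])
    print(tidy)
    # import os
    #
    # import boto3
    #
    # from olmocr.filter import PdfFilter

    # check_code() verifies ordering without rewriting, the same gate that
    # "isort --check-only --diff ." provides on the command line or in CI.
    assert isort.check_code(tidy, known_first_party=["olmocr"])
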