This commit is contained in:
Jake Poznanski 2025-01-29 15:47:57 -08:00
parent fb402297ce
commit 56903774b7
25 changed files with 30 additions and 109 deletions

View File

@ -12,7 +12,7 @@ def check_poppler_version():
if result.returncode == 0 and result.stderr.startswith("pdftoppm"):
logger.info("pdftoppm is installed and working.")
else:
logger.error(f"pdftoppm is installed but returned an error.")
logger.error("pdftoppm is installed but returned an error.")
sys.exit(1)
except FileNotFoundError:
logger.error("pdftoppm is not installed.")
@ -22,7 +22,7 @@ def check_poppler_version():
def check_sglang_version():
if importlib.util.find_spec("sglang") is None:
logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.")
sys.exit(1)

View File

@ -1,11 +1,9 @@
import argparse
import base64
import glob
import json
import os
import random
import subprocess
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Generator
from urllib.parse import urlparse

View File

@ -10,7 +10,6 @@ from pathlib import Path
import smart_open
from cached_path import cached_path
from olmocr.prompts import build_finetuning_prompt
def setup_logging():
@ -66,7 +65,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
local_pdf_path = cached_path(s3_path, quiet=True)
from olmocr.data.buildsilver import build_page_query
from olmocr.prompts.anchor import get_anchor_text
obj = build_page_query(local_pdf_path, s3_path, page)
# raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

View File

@ -3,7 +3,6 @@ import io
import subprocess
from PIL import Image
from pypdf import PdfReader
def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:

View File

@ -6,8 +6,6 @@ import datetime
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from openai import OpenAI
from tqdm import tqdm

View File

@ -13,7 +13,7 @@ from dolma_refine.evaluate.segmenters import SpacySegmenter
from tqdm import tqdm
from olmocr.eval.evalhtml import create_review_html
from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path
from olmocr.s3_utils import expand_s3_glob, get_s3_bytes
@dataclasses.dataclass

View File

@ -13,7 +13,7 @@ import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Dict, Optional
import boto3
import zstandard

View File

@ -5,7 +5,6 @@ from collections import Counter
from lingua import Language, LanguageDetectorBuilder
from pypdf import PdfReader
from pypdf.errors import DependencyError, PyPdfError
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

View File

@ -1,7 +1,6 @@
import asyncio
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Dict

View File

@ -12,8 +12,6 @@ import os
import random
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import time
@ -22,7 +20,6 @@ from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from io import BytesIO
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
import boto3
@ -44,13 +41,11 @@ from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import (
download_directory,
download_zstd_csv,
expand_s3_glob,
get_s3_bytes,
get_s3_bytes_with_backoff,
parse_s3_path,
upload_zstd_csv,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
@ -245,7 +240,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
if base_response_data["usage"]["total_tokens"] > args.model_max_context:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
raise ValueError(f"Response exceeded model_max_context, cannot use this response")
raise ValueError("Response exceeded model_max_context, cannot use this response")
metrics.add_metrics(
sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
@ -627,8 +622,8 @@ async def sglang_server_host(args, semaphore):
if retry >= MAX_RETRIES:
logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
logger.error(f"")
logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
logger.error("")
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
sys.exit(1)
@ -668,8 +663,6 @@ def submit_beaker_job(args):
from beaker import (
Beaker,
Constraints,
DataMount,
DataSource,
EnvVar,
ExperimentSpec,
ImageSource,
@ -712,7 +705,7 @@ def submit_beaker_job(args):
b.secret.write(f"{owner}-AWS_CREDENTIALS_FILE", open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(), args.beaker_workspace)
try:
b.secret.get(f"OE_DATA_GCS_SA_KEY", args.beaker_workspace)
b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
except SecretNotFound:
print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
lines = []
@ -724,7 +717,7 @@ def submit_beaker_job(args):
lines.append(line)
gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line
if gcs_sa_key:
b.secret.write(f"OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
# Create the experiment spec
experiment_spec = ExperimentSpec(
@ -748,7 +741,7 @@ def submit_beaker_job(args):
EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret=f"OE_DATA_GCS_SA_KEY"),
EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"),
],
resources=TaskResources(gpu_count=1),
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
@ -860,12 +853,12 @@ def print_stats(args):
skipped_paths = original_paths - all_processed_paths
print(f"\nWork Items Status:")
print("\nWork Items Status:")
print(f"Total work items: {total_items:,}")
print(f"Completed items: {completed_items:,}")
print(f"Remaining items: {total_items - completed_items:,}")
print(f"\nResults:")
print("\nResults:")
print(f"Total documents processed: {docs_total:,}")
print(f"Total documents skipped: {len(skipped_paths):,}")
print(f"Total pages on fallback: {fallback_pages_total:,}")

View File

@ -3,23 +3,15 @@ from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
overload,
)
from pypdf._cmap import build_char_map, unknown_char_map
from pypdf.constants import AnnotationDictionaryAttributes as ADA
from pypdf.constants import ImageAttributes as IA
from pypdf.constants import PageAttributes as PG
from pypdf.constants import Resources as RES
from pypdf.generic import (
ContentStream,
DictionaryObject,

View File

@ -13,7 +13,6 @@ import re
# coherency score best of these three
import subprocess
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Literal
import ftfy

View File

@ -5,7 +5,6 @@ import hashlib
import logging
import os
import posixpath
import tempfile
import time
from io import BytesIO, TextIOWrapper
from pathlib import Path
@ -17,8 +16,7 @@ import requests
import zstandard as zstd
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
from botocore.exceptions import ClientError, NoCredentialsError
from google.auth import compute_engine
from botocore.exceptions import ClientError
from google.cloud import storage
from tqdm import tqdm

View File

@ -1,7 +1,6 @@
import argparse
import logging
import os
from functools import partial
import boto3
from botocore.exceptions import NoCredentialsError, PartialCredentialsError

View File

@ -1,32 +1,20 @@
import base64
import glob
import json
import logging
import os
import re
import tempfile
from functools import partial
from logging import Logger
from typing import Any, Dict, Optional
from typing import Optional
import boto3
import pypdf
import pypdf.errors
from datasets import (
Dataset,
DatasetDict,
Features,
Value,
concatenate_datasets,
load_dataset,
)
from filelock import FileLock
from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import get_s3_bytes, parse_custom_id, parse_s3_path
from olmocr.s3_utils import parse_custom_id, parse_s3_path
from .core.config import DataConfig, SourceConfig
# Configure logging
logging.basicConfig(level=logging.INFO)

View File

@ -2,7 +2,6 @@ import argparse
import concurrent.futures
import json
import os
import tempfile
import boto3
import torch

View File

@ -1,35 +1,18 @@
import base64
import json
import logging
import os
import time
from functools import partial
from io import BytesIO
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import accelerate
import torch
import torch.distributed
from PIL import Image
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
Trainer,
TrainerCallback,
TrainingArguments,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
build_finetuning_prompt,
build_openai_silver_data_prompt,
)

View File

@ -1,11 +1,9 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, DataCollatorForSeq2Seq
from transformers import AutoProcessor
from olmocr.train.core.cli import make_cli
from olmocr.train.core.config import TrainConfig
from .utils import TruncatingCollator, make_dataset
from .utils import make_dataset
def main():

View File

@ -1,6 +1,5 @@
from typing import List
from transformers import AutoTokenizer, PretrainedConfig
from transformers import PretrainedConfig
class MolmoConfig(PretrainedConfig):

View File

@ -1,6 +1,6 @@
"""Image processor class for Molmo"""
from typing import List, Mapping, Optional, Union
from typing import List, Optional, Union
import einops
import numpy as np
@ -13,7 +13,6 @@ from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
is_valid_image,
)
from transformers.processing_utils import ImagesKwargs
from transformers.utils import logging

View File

@ -1,7 +1,7 @@
import logging
import math
from copy import deepcopy
from dataclasses import dataclass, fields, replace
from dataclasses import dataclass, replace
from enum import Enum
from typing import (
Any,
@ -17,7 +17,7 @@ from typing import (
)
import torch
from einops import einops, einsum
from einops import einops
from torch import nn
from torch.nn import functional as F
from transformers import GenerationConfig, PreTrainedModel
@ -1809,7 +1809,7 @@ class Molmo(nn.Module):
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
if position_ids is None:
raise ValueError(f"Positioned ids must be given if using subsegment_ids")
raise ValueError("Positioned ids must be given if using subsegment_ids")
else:
if self.config.use_position_ids and position_ids is None:
position_ids = torch.clamp(

View File

@ -4,7 +4,6 @@ Processor class for Molmo.
from typing import Optional
import PIL
from PIL import ImageOps
from PIL.Image import Image
@ -30,10 +29,10 @@ from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
logger = logging.get_logger(__name__)
DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
DEFAULT_IM_START_TOKEN = f"<im_start>"
DEFAULT_IM_END_TOKEN = f"<im_end>"
DEFAULT_IM_COL_TOKEN = f"<im_col>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)

View File

@ -1,30 +1,17 @@
import base64
import json
import logging
import os
import random
import time
from functools import partial
from io import BytesIO
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import accelerate
import torch
import torch.distributed
import wandb
from datasets import DatasetDict, concatenate_datasets
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoProcessor,
Qwen2VLForConditionalGeneration,
Trainer,

View File

@ -12,10 +12,9 @@ from tempfile import TemporaryDirectory
from typing import Dict, Generator, List, Optional, TypeVar
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import AutoProcessor
from .core.cli import to_native_types
@ -28,7 +27,7 @@ from .core.state import BeakerState
T = TypeVar("T")
from olmocr.train.dataloader import build_finetuning_dataset, list_dataset_files
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training,
batch_prepare_data_for_qwen2_training,

View File

@ -5,10 +5,8 @@ import hashlib
import logging
import os
import random
import tempfile
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional, Set
from typing import List, Optional
logger = logging.getLogger(__name__)