Ruff

parent fb402297ce
commit 56903774b7
@@ -12,7 +12,7 @@ def check_poppler_version():
         if result.returncode == 0 and result.stderr.startswith("pdftoppm"):
             logger.info("pdftoppm is installed and working.")
         else:
-            logger.error(f"pdftoppm is installed but returned an error.")
+            logger.error("pdftoppm is installed but returned an error.")
             sys.exit(1)
     except FileNotFoundError:
         logger.error("pdftoppm is not installed.")
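These paired lines are instances of Ruff rule F541 (f-string without any placeholders): an `f` prefix on a literal with no `{...}` fields does nothing at runtime and usually signals a forgotten interpolation, so the autofix simply drops the prefix. A minimal sketch of the pattern, reusing a string from this diff (the variable name is illustrative, not from olmocr):

    # Flagged by F541: no placeholders, so the f-prefix is inert.
    message = f"pdftoppm is installed but returned an error."
    # Autofix: the plain literal produces the identical string.
    message = "pdftoppm is installed but returned an error."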
@@ -22,7 +22,7 @@ def check_poppler_version():
 
 def check_sglang_version():
     if importlib.util.find_spec("sglang") is None:
-        logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
+        logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
         logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.")
         sys.exit(1)
 
@@ -1,11 +1,9 @@
 import argparse
 import base64
 import glob
 import json
 import os
 import random
 import subprocess
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, as_completed
 from typing import Generator
 from urllib.parse import urlparse
 
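Hunks like this one are Ruff's F401 autofix (unused import): names never referenced in the module, here `ThreadPoolExecutor`, are dropped from the `from ... import ...` list. A small self-contained sketch of why the fix is safe (the function and its body are illustrative, not from olmocr):

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def run_all(jobs):
        # Only ProcessPoolExecutor and as_completed are referenced,
        # so an import of ThreadPoolExecutor here would be dead code.
        with ProcessPoolExecutor() as pool:
            futures = [pool.submit(job) for job in jobs]
            return [f.result() for f in as_completed(futures)]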
@@ -10,7 +10,6 @@ from pathlib import Path
 import smart_open
 from cached_path import cached_path
 
 from olmocr.prompts import build_finetuning_prompt
 
 
 def setup_logging():
@@ -66,7 +65,6 @@ def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
     local_pdf_path = cached_path(s3_path, quiet=True)
 
     from olmocr.data.buildsilver import build_page_query
-    from olmocr.prompts.anchor import get_anchor_text
 
     obj = build_page_query(local_pdf_path, s3_path, page)
     # raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
@@ -3,7 +3,6 @@ import io
 import subprocess
 
 from PIL import Image
 from pypdf import PdfReader
 
 
 def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
@@ -6,8 +6,6 @@ import datetime
 import json
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from enum import Enum
 
 from openai import OpenAI
 from tqdm import tqdm
@@ -13,7 +13,7 @@ from dolma_refine.evaluate.segmenters import SpacySegmenter
 from tqdm import tqdm
 
 from olmocr.eval.evalhtml import create_review_html
-from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, parse_s3_path
+from olmocr.s3_utils import expand_s3_glob, get_s3_bytes
 
 
 @dataclasses.dataclass
@@ -13,7 +13,7 @@ import sys
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional
 
 import boto3
 import zstandard
@@ -5,7 +5,6 @@ from collections import Counter
 
 from lingua import Language, LanguageDetectorBuilder
 from pypdf import PdfReader
 from pypdf.errors import DependencyError, PyPdfError
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)
@@ -1,7 +1,6 @@
 import asyncio
 import time
 from collections import defaultdict, deque
 from dataclasses import dataclass, field
 from typing import Dict
 
 
@@ -12,8 +12,6 @@ import os
 import random
 import re
 import shutil
 import signal
 import subprocess
 import sys
 import tempfile
 import time
@@ -22,7 +20,6 @@ from concurrent.futures.process import BrokenProcessPool
 from dataclasses import dataclass
 from functools import cache, partial
 from io import BytesIO
 from typing import Dict, List, Optional, Set, Tuple
 from urllib.parse import urlparse
 
 import boto3
@@ -44,13 +41,11 @@ from olmocr.metrics import MetricsKeeper, WorkerTracker
 from olmocr.prompts import PageResponse, build_finetuning_prompt
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.s3_utils import (
     download_directory,
     download_zstd_csv,
     expand_s3_glob,
     get_s3_bytes,
     get_s3_bytes_with_backoff,
     parse_s3_path,
     upload_zstd_csv,
 )
 from olmocr.version import VERSION
 from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
@@ -245,7 +240,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
         if base_response_data["usage"]["total_tokens"] > args.model_max_context:
             local_anchor_text_len = max(1, local_anchor_text_len // 2)
             logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
-            raise ValueError(f"Response exceeded model_max_context, cannot use this response")
+            raise ValueError("Response exceeded model_max_context, cannot use this response")
 
         metrics.add_metrics(
             sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
@@ -627,8 +622,8 @@ async def sglang_server_host(args, semaphore):
 
         if retry >= MAX_RETRIES:
             logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
-            logger.error(f"")
-            logger.error(f"Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
+            logger.error("")
+            logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
             sys.exit(1)
 
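Note that `logger.error(f"")` is also F541: an empty f-string evaluates to the same empty string, so the fix keeps the blank spacer line in the log output unchanged. A tiny illustrative check (the logger name is hypothetical):

    import logging

    # f"" and "" are equal values; only the source syntax differs.
    assert f"" == ""
    logging.getLogger("demo").error("")  # blank spacer line, identical before and after the fix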
@@ -668,8 +663,6 @@ def submit_beaker_job(args):
     from beaker import (
         Beaker,
         Constraints,
         DataMount,
         DataSource,
         EnvVar,
         ExperimentSpec,
         ImageSource,
@@ -712,7 +705,7 @@ def submit_beaker_job(args):
     b.secret.write(f"{owner}-AWS_CREDENTIALS_FILE", open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(), args.beaker_workspace)
 
     try:
-        b.secret.get(f"OE_DATA_GCS_SA_KEY", args.beaker_workspace)
+        b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
     except SecretNotFound:
         print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
         lines = []
@@ -724,7 +717,7 @@ def submit_beaker_job(args):
             lines.append(line)
         gcs_sa_key = "\n".join(lines[:-1]).strip()  # Remove the last empty line
         if gcs_sa_key:
-            b.secret.write(f"OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
+            b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
 
     # Create the experiment spec
     experiment_spec = ExperimentSpec(
@@ -748,7 +741,7 @@ def submit_beaker_job(args):
                 EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
                 EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
                 EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
-                EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret=f"OE_DATA_GCS_SA_KEY"),
+                EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"),
             ],
             resources=TaskResources(gpu_count=1),
             constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
@@ -860,12 +853,12 @@ def print_stats(args):
 
     skipped_paths = original_paths - all_processed_paths
 
-    print(f"\nWork Items Status:")
+    print("\nWork Items Status:")
     print(f"Total work items: {total_items:,}")
     print(f"Completed items: {completed_items:,}")
     print(f"Remaining items: {total_items - completed_items:,}")
 
-    print(f"\nResults:")
+    print("\nResults:")
     print(f"Total documents processed: {docs_total:,}")
     print(f"Total documents skipped: {len(skipped_paths):,}")
     print(f"Total pages on fallback: {fallback_pages_total:,}")
@@ -3,23 +3,15 @@ from typing import (
     Any,
     Callable,
     Dict,
     Iterable,
     Iterator,
     List,
     Optional,
     Sequence,
     Set,
     Tuple,
     Union,
     cast,
     overload,
 )
 
 from pypdf._cmap import build_char_map, unknown_char_map
 from pypdf.constants import AnnotationDictionaryAttributes as ADA
 from pypdf.constants import ImageAttributes as IA
 from pypdf.constants import PageAttributes as PG
 from pypdf.constants import Resources as RES
 from pypdf.generic import (
     ContentStream,
     DictionaryObject,
@@ -13,7 +13,6 @@ import re
 # coherency score best of these three
 import subprocess
 from dataclasses import dataclass
 from functools import lru_cache
 from typing import List, Literal
 
 import ftfy
@@ -5,7 +5,6 @@ import hashlib
 import logging
 import os
 import posixpath
 import tempfile
 import time
 from io import BytesIO, TextIOWrapper
 from pathlib import Path
@@ -17,8 +16,7 @@ import requests
 import zstandard as zstd
 from boto3.s3.transfer import TransferConfig
 from botocore.config import Config
-from botocore.exceptions import ClientError, NoCredentialsError
-from google.auth import compute_engine
+from botocore.exceptions import ClientError
 from google.cloud import storage
 from tqdm import tqdm
 
@@ -1,7 +1,6 @@
 import argparse
 import logging
 import os
 from functools import partial
 
 import boto3
 from botocore.exceptions import NoCredentialsError, PartialCredentialsError
@@ -1,32 +1,20 @@
 import base64
 import glob
 import json
 import logging
 import os
 import re
 import tempfile
 from functools import partial
 from logging import Logger
-from typing import Any, Dict, Optional
+from typing import Optional
 
 import boto3
 import pypdf
 import pypdf.errors
 from datasets import (
     Dataset,
     DatasetDict,
     Features,
     Value,
     concatenate_datasets,
     load_dataset,
 )
 from filelock import FileLock
 
 from olmocr.data.renderpdf import get_pdf_media_box_width_height
 from olmocr.prompts.anchor import get_anchor_text
-from olmocr.s3_utils import get_s3_bytes, parse_custom_id, parse_s3_path
+from olmocr.s3_utils import parse_custom_id, parse_s3_path
 
 from .core.config import DataConfig, SourceConfig
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -2,7 +2,6 @@ import argparse
 import concurrent.futures
 import json
 import os
 import tempfile
 
 import boto3
 import torch
@@ -1,35 +1,18 @@
 import base64
 import json
 import logging
 import os
 import time
 from functools import partial
 from io import BytesIO
 from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
 
 import accelerate
 import torch
 import torch.distributed
 from PIL import Image
 from tqdm import tqdm
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoProcessor,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
     Trainer,
     TrainerCallback,
     TrainingArguments,
 )
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.anchor import get_anchor_text
 from olmocr.prompts.prompts import (
     build_finetuning_prompt,
     build_openai_silver_data_prompt,
 )
 
@@ -1,11 +1,9 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import AutoProcessor, DataCollatorForSeq2Seq
+from transformers import AutoProcessor
 
 from olmocr.train.core.cli import make_cli
 from olmocr.train.core.config import TrainConfig
 
-from .utils import TruncatingCollator, make_dataset
+from .utils import make_dataset
 
 
 def main():
@@ -1,6 +1,5 @@
 from typing import List
 
-from transformers import AutoTokenizer, PretrainedConfig
+from transformers import PretrainedConfig
 
 
 class MolmoConfig(PretrainedConfig):
@@ -1,6 +1,6 @@
 """Image processor class for Molmo"""
 
-from typing import List, Mapping, Optional, Union
+from typing import List, Optional, Union
 
 import einops
 import numpy as np
@@ -13,7 +13,6 @@ from transformers.image_utils import (
     OPENAI_CLIP_MEAN,
     OPENAI_CLIP_STD,
     ImageInput,
     is_valid_image,
 )
 from transformers.processing_utils import ImagesKwargs
 from transformers.utils import logging
@@ -1,7 +1,7 @@
 import logging
 import math
 from copy import deepcopy
-from dataclasses import dataclass, fields, replace
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import (
     Any,
@@ -17,7 +17,7 @@ from typing import (
 )
 
 import torch
-from einops import einops, einsum
+from einops import einops
 from torch import nn
 from torch.nn import functional as F
 from transformers import GenerationConfig, PreTrainedModel
@@ -1809,7 +1809,7 @@ class Molmo(nn.Module):
             subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
             attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
             if position_ids is None:
-                raise ValueError(f"Positioned ids must be given if using subsegment_ids")
+                raise ValueError("Positioned ids must be given if using subsegment_ids")
         else:
             if self.config.use_position_ids and position_ids is None:
                 position_ids = torch.clamp(
@@ -4,7 +4,6 @@ Processor class for Molmo.
 
 from typing import Optional
 
 import PIL
 from PIL import ImageOps
 from PIL.Image import Image
 
@@ -30,10 +29,10 @@ from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
 logger = logging.get_logger(__name__)
 
 
-DEFAULT_IMAGE_PATCH_TOKEN = f"<im_patch>"
-DEFAULT_IM_START_TOKEN = f"<im_start>"
-DEFAULT_IM_END_TOKEN = f"<im_end>"
-DEFAULT_IM_COL_TOKEN = f"<im_col>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+DEFAULT_IM_COL_TOKEN = "<im_col>"
 IMAGE_PROMPT = "<|image|>"
 
 EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
@@ -1,30 +1,17 @@
 import base64
 import json
 import logging
 import os
 import random
 import time
 from functools import partial
 from io import BytesIO
 from logging import Logger
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Optional
 
 import accelerate
 import torch
 import torch.distributed
 import wandb
 from datasets import DatasetDict, concatenate_datasets
 from datasets.utils import disable_progress_bars
 from datasets.utils.logging import set_verbosity
 from peft import LoraConfig, get_peft_model # pyright: ignore
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoProcessor,
     Qwen2VLForConditionalGeneration,
     Trainer,
@@ -12,10 +12,9 @@ from tempfile import TemporaryDirectory
 from typing import Dict, Generator, List, Optional, TypeVar
 
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.utils import PrecisionType
-from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
+from datasets import Dataset, DatasetDict, concatenate_datasets
 from transformers import AutoProcessor
 
 from .core.cli import to_native_types
@@ -28,7 +27,7 @@ from .core.state import BeakerState
 
 T = TypeVar("T")
 
-from olmocr.train.dataloader import build_finetuning_dataset, list_dataset_files
+from olmocr.train.dataloader import build_finetuning_dataset
 from olmocr.train.dataprep import (
     batch_prepare_data_for_molmo_training,
     batch_prepare_data_for_qwen2_training,
@@ -5,10 +5,8 @@ import hashlib
 import logging
 import os
 import random
 import tempfile
 from dataclasses import dataclass
 from functools import partial
-from typing import Dict, List, Optional, Set
+from typing import List, Optional
 
 logger = logging.getLogger(__name__)
 
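Every hunk in this commit is a mechanical autofix of this kind (F401 unused imports, F541 extraneous f-prefixes). A hedged sketch of reproducing the pass from a checkout, assuming the `ruff` CLI is installed; the exact rule selection configured for olmocr is not visible in this diff:

    import subprocess

    # Run Ruff's lint autofixes in place over the working tree.
    subprocess.run(["ruff", "check", "--fix", "."], check=True)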