mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-12 16:52:20 +00:00
Checking that anchor text works for each pdf page when initializing dataloader
This commit is contained in:
parent
dc7fff5bf7
commit
8e5e18f54c
@ -5,6 +5,7 @@ import random
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from os import PathLike
|
||||
from typing import List, Literal
|
||||
|
||||
import ftfy
|
||||
@ -16,7 +17,7 @@ from olmocr.filter.coherency import get_document_coherency
|
||||
|
||||
|
||||
def get_anchor_text(
|
||||
local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
|
||||
local_pdf_path: str | PathLike, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
|
||||
) -> str:
|
||||
assert page > 0, "Pages are 1-indexed in pdf-land"
|
||||
|
||||
|
@ -12,6 +12,7 @@ from omegaconf import DictConfig, OmegaConf
|
||||
@dataclass
|
||||
class PipelineStepConfig:
|
||||
"""Base configuration for pipeline steps."""
|
||||
|
||||
name: str
|
||||
enabled: bool = True
|
||||
|
||||
@ -19,6 +20,7 @@ class PipelineStepConfig:
|
||||
@dataclass
|
||||
class FrontMatterParserConfig(PipelineStepConfig):
|
||||
"""Configuration for FrontMatterParser step."""
|
||||
|
||||
name: str = "FrontMatterParser"
|
||||
use_page_response_class: bool = True # Whether to use PageResponse dataclass
|
||||
|
||||
@ -26,6 +28,7 @@ class FrontMatterParserConfig(PipelineStepConfig):
|
||||
@dataclass
|
||||
class PDFRendererConfig(PipelineStepConfig):
|
||||
"""Configuration for PDFRenderer step."""
|
||||
|
||||
name: str = "PDFRenderer"
|
||||
target_longest_image_dim: int = 1024
|
||||
|
||||
@ -33,6 +36,7 @@ class PDFRendererConfig(PipelineStepConfig):
|
||||
@dataclass
|
||||
class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
|
||||
"""Configuration for StaticLengthDocumentAnchoring step."""
|
||||
|
||||
name: str = "StaticLengthDocumentAnchoring"
|
||||
target_anchor_text_len: int = 6000
|
||||
|
||||
@ -40,24 +44,28 @@ class StaticLengthDocumentAnchoringConfig(PipelineStepConfig):
|
||||
@dataclass
|
||||
class FinetuningPromptConfig(PipelineStepConfig):
|
||||
"""Configuration for FinetuningPrompt step."""
|
||||
|
||||
name: str = "FinetuningPrompt"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontMatterOutputFormatConfig(PipelineStepConfig):
|
||||
"""Configuration for FrontMatterOutputFormat step."""
|
||||
|
||||
name: str = "FrontMatterOutputFormat"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstructUserMessagesConfig(PipelineStepConfig):
|
||||
"""Configuration for InstructUserMessages step."""
|
||||
|
||||
name: str = "InstructUserMessages"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerStepConfig(PipelineStepConfig):
|
||||
"""Configuration for Tokenizer step."""
|
||||
|
||||
name: str = "Tokenizer"
|
||||
masking_index: int = -100
|
||||
end_of_message_token: str = "<|im_end|>"
|
||||
@ -66,6 +74,7 @@ class TokenizerStepConfig(PipelineStepConfig):
|
||||
@dataclass
|
||||
class DatasetItemConfig:
|
||||
"""Configuration for a single dataset item."""
|
||||
|
||||
root_dir: str
|
||||
pipeline: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@ -76,6 +85,7 @@ class DatasetItemConfig:
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
"""Configuration for dataset and data loading."""
|
||||
|
||||
train: List[Dict[str, Any]] = field(default_factory=list)
|
||||
eval: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@ -83,6 +93,7 @@ class DatasetConfig:
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""Configuration for model."""
|
||||
|
||||
name: str = "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
trust_remote_code: bool = False
|
||||
|
||||
@ -112,6 +123,7 @@ class ModelConfig:
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Configuration for training parameters."""
|
||||
|
||||
output_dir: str = "./outputs"
|
||||
num_train_epochs: int = 3
|
||||
per_device_train_batch_size: int = 1
|
||||
@ -181,6 +193,7 @@ class TrainingConfig:
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Main configuration class that combines all sub-configs."""
|
||||
|
||||
model: ModelConfig
|
||||
dataset: DatasetConfig
|
||||
training: TrainingConfig
|
||||
@ -208,7 +221,7 @@ class Config:
|
||||
raise FileNotFoundError(f"Config file not found: {yaml_path}")
|
||||
|
||||
# Load YAML with OmegaConf for better features
|
||||
with open(yaml_path, 'r') as f:
|
||||
with open(yaml_path, "r") as f:
|
||||
yaml_content = yaml.safe_load(f)
|
||||
|
||||
# Create OmegaConf config for interpolation and validation
|
||||
@ -221,20 +234,14 @@ class Config:
|
||||
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
|
||||
# Create sub-configs
|
||||
model_cfg = ModelConfig(**cfg_dict.get('model', {}))
|
||||
dataset_cfg = DatasetConfig(**cfg_dict.get('dataset', {}))
|
||||
training_cfg = TrainingConfig(**cfg_dict.get('training', {}))
|
||||
model_cfg = ModelConfig(**cfg_dict.get("model", {}))
|
||||
dataset_cfg = DatasetConfig(**cfg_dict.get("dataset", {}))
|
||||
training_cfg = TrainingConfig(**cfg_dict.get("training", {}))
|
||||
|
||||
# Create main config
|
||||
main_cfg_dict = {k: v for k, v in cfg_dict.items()
|
||||
if k not in ['model', 'dataset', 'training']}
|
||||
main_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in ["model", "dataset", "training"]}
|
||||
|
||||
return cls(
|
||||
model=model_cfg,
|
||||
dataset=dataset_cfg,
|
||||
training=training_cfg,
|
||||
**main_cfg_dict
|
||||
)
|
||||
return cls(model=model_cfg, dataset=dataset_cfg, training=training_cfg, **main_cfg_dict)
|
||||
|
||||
def to_yaml(self, yaml_path: Union[str, Path]) -> None:
|
||||
"""Save configuration to YAML file."""
|
||||
@ -244,7 +251,7 @@ class Config:
|
||||
# Convert to OmegaConf for nice YAML output
|
||||
cfg = OmegaConf.structured(self)
|
||||
|
||||
with open(yaml_path, 'w') as f:
|
||||
with open(yaml_path, "w") as f:
|
||||
OmegaConf.save(cfg, f)
|
||||
|
||||
def validate(self) -> None:
|
||||
@ -252,7 +259,7 @@ class Config:
|
||||
# Dataset validation - check all train and eval datasets
|
||||
for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
|
||||
for i, dataset_cfg in enumerate(datasets):
|
||||
root_dir = dataset_cfg.get('root_dir')
|
||||
root_dir = dataset_cfg.get("root_dir")
|
||||
if not root_dir:
|
||||
raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
|
||||
if not os.path.exists(root_dir):
|
||||
@ -284,60 +291,59 @@ class Config:
|
||||
Returns:
|
||||
List of initialized pipeline step instances
|
||||
"""
|
||||
from olmocr.prompts.prompts import PageResponse
|
||||
from olmocr.train.dataloader import (
|
||||
FrontMatterParser,
|
||||
PDFRenderer,
|
||||
StaticLengthDocumentAnchoring,
|
||||
FinetuningPrompt,
|
||||
FrontMatterOutputFormat,
|
||||
FrontMatterParser,
|
||||
InstructUserMessages,
|
||||
Tokenizer
|
||||
PDFRenderer,
|
||||
StaticLengthDocumentAnchoring,
|
||||
Tokenizer,
|
||||
)
|
||||
from olmocr.prompts.prompts import PageResponse
|
||||
|
||||
steps = []
|
||||
for step_config in pipeline_config:
|
||||
if not step_config.get('enabled', True):
|
||||
if not step_config.get("enabled", True):
|
||||
continue
|
||||
|
||||
step_name = step_config['name']
|
||||
step_name = step_config["name"]
|
||||
|
||||
if step_name == 'FrontMatterParser':
|
||||
if step_name == "FrontMatterParser":
|
||||
# Handle both old and new config format
|
||||
if 'front_matter_class' in step_config:
|
||||
front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
|
||||
if "front_matter_class" in step_config:
|
||||
front_matter_class = PageResponse if step_config["front_matter_class"] == "PageResponse" else None
|
||||
else:
|
||||
front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
|
||||
front_matter_class = PageResponse if step_config.get("use_page_response_class", True) else None
|
||||
steps.append(FrontMatterParser(front_matter_class=front_matter_class))
|
||||
|
||||
elif step_name == 'PDFRenderer':
|
||||
steps.append(PDFRenderer(
|
||||
target_longest_image_dim=step_config.get('target_longest_image_dim', 1024),
|
||||
image_transform=None # Can be extended later
|
||||
))
|
||||
elif step_name == "PDFRenderer":
|
||||
steps.append(
|
||||
PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024), image_transform=None) # Can be extended later
|
||||
)
|
||||
|
||||
elif step_name == 'StaticLengthDocumentAnchoring':
|
||||
steps.append(StaticLengthDocumentAnchoring(
|
||||
target_anchor_text_len=step_config.get('target_anchor_text_len', 6000)
|
||||
))
|
||||
elif step_name == "StaticLengthDocumentAnchoring":
|
||||
steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
|
||||
|
||||
elif step_name == 'FinetuningPrompt':
|
||||
elif step_name == "FinetuningPrompt":
|
||||
steps.append(FinetuningPrompt())
|
||||
|
||||
elif step_name == 'FrontMatterOutputFormat':
|
||||
elif step_name == "FrontMatterOutputFormat":
|
||||
steps.append(FrontMatterOutputFormat())
|
||||
|
||||
elif step_name == 'InstructUserMessages':
|
||||
elif step_name == "InstructUserMessages":
|
||||
steps.append(InstructUserMessages())
|
||||
|
||||
elif step_name == 'Tokenizer':
|
||||
elif step_name == "Tokenizer":
|
||||
if processor is None:
|
||||
raise ValueError("Processor must be provided for Tokenizer step")
|
||||
steps.append(Tokenizer(
|
||||
steps.append(
|
||||
Tokenizer(
|
||||
processor=processor,
|
||||
masking_index=step_config.get('masking_index', -100),
|
||||
end_of_message_token=step_config.get('end_of_message_token', '<|im_end|>')
|
||||
))
|
||||
masking_index=step_config.get("masking_index", -100),
|
||||
end_of_message_token=step_config.get("end_of_message_token", "<|im_end|>"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown pipeline step: {step_name}")
|
||||
|
||||
@ -346,11 +352,7 @@ class Config:
|
||||
|
||||
def create_default_config() -> Config:
|
||||
"""Create a default configuration."""
|
||||
return Config(
|
||||
model=ModelConfig(),
|
||||
dataset=DatasetConfig(),
|
||||
training=TrainingConfig()
|
||||
)
|
||||
return Config(model=ModelConfig(), dataset=DatasetConfig(), training=TrainingConfig())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -71,7 +71,7 @@ class BaseMarkdownPDFDataset(Dataset):
|
||||
|
||||
# Verify the resolved path exists
|
||||
if pdf_path.exists():
|
||||
# Validate PDF - check it loads and has exactly one page
|
||||
# Validate PDF - check it loads and has exactly one page and that you can get document-anchoring from it
|
||||
try:
|
||||
reader = PdfReader(str(pdf_path))
|
||||
num_pages = len(reader.pages)
|
||||
@ -80,6 +80,9 @@ class BaseMarkdownPDFDataset(Dataset):
|
||||
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
|
||||
continue
|
||||
|
||||
# Test that document anchoring works
|
||||
get_anchor_text(pdf_path, page=1, pdf_engine="pdfreport", target_length=100)
|
||||
|
||||
self.samples.append({"markdown_path": md_path, "pdf_path": pdf_path})
|
||||
valid_count += 1
|
||||
|
||||
|
@ -4,18 +4,18 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
EarlyStoppingCallback
|
||||
)
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import ConcatDataset
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
EarlyStoppingCallback,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from olmocr.train.config import Config
|
||||
from olmocr.train.dataloader import BaseMarkdownPDFDataset
|
||||
@ -34,53 +34,44 @@ class QwenDataCollator:
|
||||
|
||||
def __call__(self, examples):
|
||||
# Filter out None values and extract the fields we need
|
||||
batch = {
|
||||
'input_ids': [],
|
||||
'attention_mask': [],
|
||||
'labels': [],
|
||||
'pixel_values': [],
|
||||
'image_grid_thw': []
|
||||
}
|
||||
batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}
|
||||
|
||||
for example in examples:
|
||||
if example is not None:
|
||||
# Convert numpy arrays to tensors
|
||||
batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
|
||||
batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
|
||||
batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
|
||||
batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"])
|
||||
batch["attention_mask"].append(
|
||||
torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
|
||||
)
|
||||
batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"])
|
||||
|
||||
# Handle pixel_values which might be numpy array or already a tensor
|
||||
pixel_values = example['pixel_values']
|
||||
pixel_values = example["pixel_values"]
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = torch.from_numpy(pixel_values)
|
||||
batch['pixel_values'].append(pixel_values)
|
||||
batch["pixel_values"].append(pixel_values)
|
||||
|
||||
# Handle image_grid_thw
|
||||
image_grid_thw = example['image_grid_thw']
|
||||
image_grid_thw = example["image_grid_thw"]
|
||||
if isinstance(image_grid_thw, np.ndarray):
|
||||
image_grid_thw = torch.from_numpy(image_grid_thw)
|
||||
batch['image_grid_thw'].append(image_grid_thw)
|
||||
batch["image_grid_thw"].append(image_grid_thw)
|
||||
|
||||
# Stack the per-example tensors into batch tensors (torch.stack does not pad, so examples must already share a shape)
|
||||
# Note: For Qwen2-VL, we typically handle variable length sequences
|
||||
# The model's processor should handle the padding internally
|
||||
return {
|
||||
'input_ids': torch.stack(batch['input_ids']),
|
||||
'attention_mask': torch.stack(batch['attention_mask']),
|
||||
'labels': torch.stack(batch['labels']),
|
||||
'pixel_values': torch.stack(batch['pixel_values']), # Stack into tensor
|
||||
'image_grid_thw': torch.stack(batch['image_grid_thw'])
|
||||
"input_ids": torch.stack(batch["input_ids"]),
|
||||
"attention_mask": torch.stack(batch["attention_mask"]),
|
||||
"labels": torch.stack(batch["labels"]),
|
||||
"pixel_values": torch.stack(batch["pixel_values"]), # Stack into tensor
|
||||
"image_grid_thw": torch.stack(batch["image_grid_thw"]),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train OlmOCR model")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default="olmocr/train/configs/example_config.yaml",
|
||||
help="Path to YAML configuration file"
|
||||
)
|
||||
parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -130,8 +121,8 @@ def main():
|
||||
logger.info("Creating training datasets...")
|
||||
train_datasets = []
|
||||
for i, dataset_cfg in enumerate(config.dataset.train):
|
||||
root_dir = dataset_cfg['root_dir']
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
||||
root_dir = dataset_cfg["root_dir"]
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)
|
||||
|
||||
logger.info(f"Creating training dataset {i+1} from: {root_dir}")
|
||||
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
||||
@ -148,11 +139,11 @@ def main():
|
||||
logger.info("Creating evaluation datasets...")
|
||||
eval_datasets = {}
|
||||
for i, dataset_cfg in enumerate(config.dataset.eval):
|
||||
root_dir = dataset_cfg['root_dir']
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
||||
root_dir = dataset_cfg["root_dir"]
|
||||
pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)
|
||||
|
||||
# Use dataset name if provided, otherwise fall back to a positional name (eval_dataset_<n>)
|
||||
dataset_name = dataset_cfg.get('name', f"eval_dataset_{i+1}")
|
||||
dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}")
|
||||
|
||||
logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
|
||||
dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
||||
@ -214,8 +205,7 @@ def main():
|
||||
if config.training.use_early_stopping:
|
||||
callbacks.append(
|
||||
EarlyStoppingCallback(
|
||||
early_stopping_patience=config.training.early_stopping_patience,
|
||||
early_stopping_threshold=config.training.early_stopping_threshold
|
||||
early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
|
||||
)
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user