Ok, rotation augmentation is in

This commit is contained in:
Jake Poznanski 2025-08-04 21:15:36 +00:00
parent 3bc2c0b8e3
commit 0792c03a9a
2 changed files with 123 additions and 0 deletions

View File

@ -100,6 +100,30 @@ class TokenizerStepConfig(PipelineStepConfig):
end_of_message_token: str = "<|im_end|>"
@dataclass
class RandomTokenFlipperConfig(PipelineStepConfig):
"""Configuration for RandomTokenFlipper step."""
name: str = "RandomTokenFlipper"
token_flip_rate: float = 1e-4
masking_index: int = -100
@dataclass
class FilterOutRotatedDocumentsConfig(PipelineStepConfig):
"""Configuration for FilterOutRotatedDocuments step."""
name: str = "FilterOutRotatedDocuments"
@dataclass
class RotationAugmentationConfig(PipelineStepConfig):
"""Configuration for RotationAugmentation step."""
name: str = "RotationAugmentation"
probability: float = 0.5
@dataclass
class DatasetItemConfig:
"""Configuration for a single dataset item."""
@ -329,6 +353,7 @@ class Config:
"""
from olmocr.prompts.prompts import PageResponse
from olmocr.train.dataloader import (
FilterOutRotatedDocuments,
FinetuningPrompt,
FrontMatterOutputFormat,
FrontMatterParser,
@ -339,6 +364,7 @@ class Config:
NewYamlFinetuningPromptWithNoAnchoring,
PDFRenderer,
RandomTokenFlipper,
RotationAugmentation,
StaticLengthDocumentAnchoring,
Tokenizer,
)
@ -422,6 +448,17 @@ class Config:
masking_index=step_config.get("masking_index", -100),
)
)
elif step_name == "FilterOutRotatedDocuments":
steps.append(FilterOutRotatedDocuments())
elif step_name == "RotationAugmentation":
steps.append(
RotationAugmentation(
probability=step_config.get("probability", 0.5)
)
)
else:
raise ValueError(f"Unknown pipeline step: {step_name}")

View File

@ -438,6 +438,92 @@ class LatexBracketNormalizer(PipelineStep):
return sample
@dataclass(frozen=True, slots=True)
class RotationAugmentation(PipelineStep):
"""Pipeline step that randomly rotates images for augmentation."""
probability: float = 0.5 # Probability of applying rotation
def __call__(self, sample: Sample) -> Optional[Sample]:
"""Randomly rotate image and update rotation metadata."""
# Only proceed with given probability
if np.random.random() > self.probability:
return sample
# Check if image exists
if "image" not in sample:
return sample
# Check if page_data exists (we need to update it)
if "page_data" not in sample:
return sample
# Randomly choose a rotation (90, 180, or 270 degrees)
rotation_degrees = np.random.choice([90, 180, 270])
# Apply rotation to image
image = sample["image"]
if rotation_degrees == 90:
transpose = Image.Transpose.ROTATE_90
elif rotation_degrees == 180:
transpose = Image.Transpose.ROTATE_180
else: # 270
transpose = Image.Transpose.ROTATE_270
rotated_image = image.transpose(transpose)
sample["image"] = rotated_image
# Update page_data
page_data = sample["page_data"]
# Create new PageResponse with updated rotation info
# The rotation_correction should be the inverse of what we applied
# If we rotated 90 clockwise, we need 270 counter-clockwise to correct it
if rotation_degrees == 90:
correction = 270
elif rotation_degrees == 180:
correction = 180
else: # 270
correction = 90
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=False, # Mark as invalid since we rotated it
rotation_correction=correction, # The correction needed to fix it
is_table=page_data.is_table,
is_diagram=page_data.is_diagram,
natural_text=page_data.natural_text,
)
sample["page_data"] = new_page_data
return sample
@dataclass(frozen=True, slots=True)
class FilterOutRotatedDocuments(PipelineStep):
"""Pipeline step that filters out documents with rotation issues."""
def __call__(self, sample: Sample) -> Optional[Sample]:
"""Filter out samples where rotation is invalid or rotation correction is needed."""
# Check if page_data exists
if "page_data" not in sample:
return sample
page_data = sample["page_data"]
# Check if page_data has the required attributes
if not hasattr(page_data, "is_rotation_valid") or not hasattr(page_data, "rotation_correction"):
return sample
# Filter out if rotation is invalid or rotation correction is not 0
if page_data.is_rotation_valid is False or page_data.rotation_correction != 0:
return None
return sample
@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""